In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

import warnings
warnings.filterwarnings('ignore')

In [2]:
class BartEncoder(nn.Module):
    """Turn raw text into BART encoder hidden states.

    Bundles a ``BartTokenizer`` with a ``BartForConditionalGeneration``
    checkpoint and exposes only the encoder half of the model.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
            Defaults to ``'facebook/bart-large'``.
        max_length: Maximum number of tokens kept per input; longer inputs
            are truncated by the tokenizer. Defaults to 1000.
    """

    def __init__(self, model_name='facebook/bart-large', max_length=1000):
        super().__init__()
        self.max_length = max_length
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, text):
        """Tokenize ``text`` and run it through the BART encoder.

        Args:
            text: Input accepted by the tokenizer; the notebook passes a
                list of strings (one batch entry per string).

        Returns:
            A ``(last_hidden_state, attention_mask)`` tuple. For the batch of
            five column-name strings in this notebook the hidden states have
            size ``(5, 11, 1024)``. Both tensors live on the model's device.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        # The tokenizer always produces CPU tensors; move them next to the
        # model weights so the module keeps working after `.to(device)`.
        device = self.model.device
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        encoder_outputs = self.model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return encoder_outputs.last_hidden_state, attention_mask

class BartDecoder(nn.Module):
    """Reconstruct text from BART encoder hidden states via beam search.

    Bundles a ``BartTokenizer`` with a ``BartForConditionalGeneration``
    checkpoint and drives ``generate`` from precomputed encoder outputs.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
            Defaults to ``'facebook/bart-large'``.
        max_length: ``max_length`` forwarded to ``generate``. Defaults to 100.
        num_beams: ``num_beams`` forwarded to ``generate``. Defaults to 5.
    """

    def __init__(self, model_name='facebook/bart-large', max_length=100, num_beams=5):
        super().__init__()
        self.max_length = max_length
        self.num_beams = num_beams
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, last_hidden_state, attention_mask):
        """Generate one text string per row of encoder hidden states.

        Args:
            last_hidden_state: Encoder hidden states; the first dimension is
                used as the batch size.
            attention_mask: Attention mask that accompanied the encoder input.

        Returns:
            A list of decoded strings (special tokens stripped), one per
            generated sequence.
        """
        # Keep every tensor on the model's device; the start tokens used to be
        # created on the CPU unconditionally, which breaks after `.to(device)`.
        device = self.model.device
        last_hidden_state = last_hidden_state.to(device)
        attention_mask = attention_mask.to(device)

        # One decoder start token per batch row.
        batch_size = last_hidden_state.size(0)
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.model.config.decoder_start_token_id,
            dtype=torch.long,
            device=device,
        )

        # `generate` expects encoder outputs wrapped in a model-output object.
        encoder_outputs = BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=None,
            attentions=None,
        )
        outputs = self.model.generate(
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            max_length=self.max_length,
            num_beams=self.num_beams,
        )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Build the encoder/decoder wrappers once for the whole notebook. Each
# constructor calls `from_pretrained` for both the tokenizer and the model,
# so the BART checkpoint is loaded twice here.
TextEncoder = BartEncoder()
TextDecoder = BartDecoder()

In [3]:
# Load the dataset from the UTF-8 encoded CSV in the working directory and
# preview its first rows.
data = pd.read_csv("data.csv", encoding="utf-8")
data.head()

Unnamed: 0,product_description,product_type,sentiment
0,#techcrunch #google This Post Has Nothing to d...,9,2
1,Data is the new oil. (Companies like Google an...,3,1
2,my sister is throwing the Google sxsw party to...,3,3
3,Clear +succinct visions make for great UX (thi...,9,2
4,40% of Google Maps use is mobile marissamayer ...,9,3


In [4]:
# Keep the column labels; later cells join them into one comma-separated
# string that is fed to the text encoder.
col_name = data.columns
print(col_name)

Index(['product_description', 'product_type', 'sentiment'], dtype='object')


In [6]:
# Preview the encoder input: the comma-joined column names repeated 5 times,
# i.e. a list of 5 identical strings.
[', '.join(col_name)] * 5

['product_description, product_type, sentiment',
 'product_description, product_type, sentiment',
 'product_description, product_type, sentiment',
 'product_description, product_type, sentiment',
 'product_description, product_type, sentiment']

In [7]:
# Encode the batch of 5 identical column-name strings; the encoder returns the
# hidden states together with the attention mask needed for decoding.
col_name_emd, col_name_mask = TextEncoder([', '.join(col_name)] * 5)
print(col_name_emd.size())

torch.Size([5, 11, 1024])


In [8]:
# Round-trip check: decode the hidden states back into text and display the
# result so it can be compared with the strings that were encoded.
reconstructed_col_name = TextDecoder(col_name_emd, col_name_mask)
reconstructed_col_name

['product_description, product_type, sentiment',
 'product_description, product_type, sentiment',
 'product_description, product_type, sentiment',
 'product_description, product_type, sentiment',
 'product_description, product_type, sentiment']