In [1]:
import torch
from torch.utils.data import Dataset
import pandas as pd
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
special_token_dict = {"bos_token": "<s>"}
tokenizer.add_special_tokens(special_token_dict)

tokenizer.encode("hi my name is neet")
import torch
import torch
from torch.utils.data import Dataset
import pandas as pd

class WMTDataset(Dataset):
    
    def __init__(self, data_path, src_tokenizer, tgt_tokenizer, seq_len):
        super().__init__()
        self.data = pd.read_csv(data_path)
        self.src_vocab_size = src_tokenizer.vocab_size
        self.tgt_vocab_size = tgt_tokenizer.vocab_size
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.seq_len = seq_len
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        sos_token = self.src_tokenizer.encode(['<s>'])[0]
        eos_token = self.src_tokenizer.encode(['</s>'])[0]
        pad_token = self.src_tokenizer.encode(['<pad>'])[0]
        src_encoding = self.src_tokenizer.encode(self.data.iloc[index]['en'])[:-1] # remove default eos token
        tgt_encoding = self.tgt_tokenizer.encode(self.data.iloc[index]['de'])[:-1] # remove default eos token
        print("len of src sen: ", len(src_encoding))
        print("len of tgt sen: ", len(tgt_encoding))
        assert len(src_encoding) < self.seq_len + 2, "sentence too big"
        assert len(tgt_encoding) < self.seq_len + 2, "sentence too big"
        
        src_padding_len = self.seq_len - (len(src_encoding) + 2)  
        tgt_padding_len = self.seq_len - (len(tgt_encoding) + 2) 
        
        src_encoding = torch.tensor([sos_token] + src_encoding + [eos_token] + [pad_token]*src_padding_len, dtype=torch.long)
        tgt_encoding = torch.tensor([sos_token] + tgt_encoding + [eos_token] + [pad_token]*tgt_padding_len, dtype=torch.long)
        
        causal_mask = torch.triu(torch.ones(self.seq_len, self.seq_len, dtype=bool), diagonal=1).to(bool)

        src_mask = (src_encoding == pad_token).unsqueeze(0)
        tgt_mask = (tgt_encoding == pad_token).unsqueeze(0)
                
        return src_encoding, tgt_encoding, src_mask, tgt_mask
    
ds = WMTDataset("wmt14_translate_de-en_test.csv", tokenizer, tokenizer, 200)



In [2]:
from torch.utils.data import DataLoader
from model import build_transformer, generate_causal_mask
from config import TransformerConfig

config = TransformerConfig()

model = build_transformer(config)

dataloader = DataLoader(ds, batch_size=2, shuffle=True)

loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids("<pad>"), label_smoothing=0.1)



In [3]:
for batch in dataloader:
    src, tgt, src_mask, tgt_mask = batch
    enc_self_attn_mask = src_mask.unsqueeze(2) | src_mask.unsqueeze(3)
    
    causal_mask = generate_causal_mask(200)
    
    dec_self_attn_mask = tgt_mask.unsqueeze(2) | tgt_mask.unsqueeze(3) | causal_mask
    
    dec_cross_attn_mask = tgt_mask.unsqueeze(3) | src_mask.unsqueeze(2)
    
    y = model(src, tgt, enc_self_attn_mask, dec_self_attn_mask, dec_cross_attn_mask)
    
    
    
    break

len of src sen:  10
len of tgt sen:  15
len of src sen:  26
len of tgt sen:  34


In [4]:
loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.convert_tokens_to_ids(tokenizer.pad_token), label_smoothing=config.label_smooting)

AttributeError: 'TransformerConfig' object has no attribute 'label_smooting'

In [5]:
tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

58100

In [9]:
from torch.utils.data import Subset

# Suppose dataset is your Dataset instance
sliced_dataset = Subset(ds, indices=range(0, 2))


In [10]:
import pandas as pd

df = pd.read_csv(config.train_data_path)

# Slice top 2 rows
df_overfit = df.head(2)

# Save to new CSV
df_overfit.to_csv("overfit.csv", index=False)

In [12]:
y = torch.rand(15, 100)
z = torch.randint(0, 5, (15,))

In [15]:
loss(-y,z)

tensor(4.7501)

In [14]:
y

tensor([[0.3709, 0.0303, 0.7093,  ..., 0.3078, 0.6270, 0.5365],
        [0.1380, 0.1782, 0.8531,  ..., 0.7949, 0.5710, 0.3463],
        [0.3684, 0.7031, 0.5402,  ..., 0.0320, 0.9080, 0.7795],
        ...,
        [0.0842, 0.5291, 0.9343,  ..., 0.4297, 0.9429, 0.2803],
        [0.7740, 0.4271, 0.4210,  ..., 0.2886, 0.3065, 0.2006],
        [0.2793, 0.2835, 0.6560,  ..., 0.7539, 0.9754, 0.7148]])

In [None]:
import torch.nn as nn

nn.Transformer()