In [14]:
import torch
from datasets import load_from_disk

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [2]:
# Load the pre-tokenised dataset from disk and have it hand back torch tensors
# instead of Python lists when rows are indexed.
dataset = load_from_disk("tokenised_wmt14").with_format("torch")


In [8]:
from transformers import PreTrainedTokenizerFast

# Load the tokeniser saved in ./trained_tokeniser.
tokeniser = PreTrainedTokenizerFast.from_pretrained("./trained_tokeniser")

# Sanity check: show one training example and decode its target side back to text.
data1 = dataset["train"][1]
print(data1)

# NOTE: despite its name, `input_ids_list` holds the example's *labels*
# (target-side token ids), not its `input_ids`.
input_ids_list = data1["labels"]
decoded_labels = tokeniser.decode(input_ids_list)
print(decoded_labels)


{'input_ids': tensor([ 1775,  1245, 13027,  1007,  2988,  1126,  1212,  3902,  1033,  1024,
         3823,  1055,  1024,  2893,  2995, 17793,  3085,  1212, 24621, 28891]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([ 1754,  1948,  3633,  1074,  7316,  1484,  1500,  3747,  1165, 17168,
         1644, 17319,  1033,  1156,  4890,  2851,  1079])}
Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den nÃ¤chsten Tagen.
{'input_ids': tensor([3155, 1152, 6021, 9327, 1117, 1212, 9288,   10, 1014, 7415, 7938]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([ 1397,  2984,  4790,  1255,  1139,  1500,  5063,  1373,  2520,  1463,
         1139,  1179, 12708])}


In [9]:
# Show the available splits with their feature names and row counts.
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2449617
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2062
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1969
    })
})


In [None]:
# model definition:
import torch.nn as nn

# Note: torch's (Bi)LSTM processes all timesteps of a sequence in a single call.

class Bilstm_Encoder(nn.Module):
    """Embed source token ids and encode them with a single-layer bidirectional LSTM.

    Args:
        input_size: Number of entries in the embedding table (i.e. the source
            vocabulary size), NOT the shape of the input batch.
        embedding_dim: Size of each token embedding vector.
        hidden_size: Hidden size of the LSTM *per direction*.
    """

    def __init__(self, input_size, embedding_dim, hidden_size):
        super().__init__()
        # Embedding matrix C: maps token ids (B, L) -> vectors (B, L, embedding_dim).
        self.embedding_matrix = nn.Embedding(input_size, embedding_dim)
        # batch_first=True so inputs and outputs are laid out as (B, L, features).
        self.bilstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, input):  # input: (B, L) token ids
        """Encode a batch of token-id sequences.

        Args:
            input: Integer tensor of token ids with shape (B, L).

        Returns:
            A pair ``(output, (h_n, c_n))``:
              * ``output``: (B, L, 2 * hidden_size) -- forward and backward hidden
                states of the top layer, concatenated, for every timestep.
              * ``h_n`` / ``c_n``: (2 * num_layers, B, hidden_size) -- final hidden
                and cell state for each direction of each layer.
        """
        embedded = self.embedding_matrix(input)  # (B, L, embedding_dim)
        # nn.LSTM returns (output, (h_n, c_n)); unpack the state tuple explicitly
        # so the names match what they actually hold.
        output, (h_n, c_n) = self.bilstm(embedded)
        return output, (h_n, c_n)

# Compute the score f(h_i, s_j) for every encoder state h_i, then apply a softmax over the source positions.
class Luong_attention(nn.Module):
    """Luong-style "general" (bilinear) attention: score(s, h_i) = s^T W h_i.

    Args:
        encoder_dim: Feature size of each encoder state (the ``values``).
        decoder_dim: Feature size of the decoder state (the ``query``).
    """

    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.decoder_dim = decoder_dim
        self.encoder_dim = encoder_dim
        # Bilinear weight W of shape (decoder_dim, encoder_dim), initialised
        # uniformly in [-0.1, 0.1].
        self.W = nn.Parameter(
            torch.empty(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)
        )

    def forward(self, query, values, mask=None):
        """Compute the context vector for one decoder step.

        Args:
            query: Decoder state, shape (B, decoder_dim).
            values: Encoder states, shape (B, L, encoder_dim).
            mask: Optional (B, L) tensor; zero/False entries mark positions
                (e.g. padding) that must receive no attention. Every row needs
                at least one non-masked position, otherwise the softmax yields NaN.

        Returns:
            Context vector of shape (B, encoder_dim): the attention-weighted
            sum of ``values``.
        """
        # Apply W to the query first: (B, 1, decoder_dim) @ (decoder_dim, encoder_dim)
        # -> (B, 1, encoder_dim). Same result as projecting every encoder state,
        # but without materialising a (B, decoder_dim, L) tensor.
        projected_query = query.unsqueeze(1) @ self.W
        # Unnormalised scores: (B, 1, encoder_dim) @ (B, encoder_dim, L) -> (B, 1, L).
        attention_scores = projected_query @ values.transpose(1, 2)
        if mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            attention_scores = attention_scores.masked_fill(
                ~mask.bool().unsqueeze(1), float("-inf")
            )
        # Normalise over the source positions.
        attention_weights = torch.softmax(attention_scores, dim=-1)  # (B, 1, L)
        # (B, 1, L) @ (B, L, encoder_dim) -> (B, 1, encoder_dim) -> (B, encoder_dim).
        context_vector = (attention_weights @ values).squeeze(1)
        return context_vector

class Bilstm_Decoder(nn.Module):
    """Decoder skeleton -- only the constructor exists so far.

    TODO: add the target embedding, recurrent layer, attention and output
    projection, plus a ``forward`` method.

    Args:
        input_size: Stored as-is; not yet consumed by any layer.
        hidden_size: Stored as-is; not yet consumed by any layer.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Keep the hyperparameters so the layers added later can use them.
        self.input_size = input_size
        self.hidden_size = hidden_size



            







In [None]:
# The training loop

from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Batching setup: the seq2seq collator pads every batch dynamically.
batch_size = 32
collate_fn = DataCollatorForSeq2Seq(tokeniser, padding=True)

# Shuffled loader over the training split.
train_dataloader = DataLoader(
    dataset["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Pull a single batch, move it to the target device and report the tensor shapes.
for batch in train_dataloader:
    input_ids, attention_mask, labels = (
        batch[field].to(device)
        for field in ("input_ids", "attention_mask", "labels")
    )
    print(input_ids.shape, attention_mask.shape, labels.shape, sep="\n")
    break

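# The labels' sequence length differs from that of input_ids because labels are target-side: they line up one-to-one with the decoder's outputs, and the collator pads source and target sequences independently.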


torch.Size([32, 46])
torch.Size([32, 46])
torch.Size([32, 49])
