In [4]:
import torch
from datasets import load_from_disk

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
dataset = load_from_disk("tokenised_wmt14")
dataset = dataset.with_format("torch") # turns the lists to torch tensors

train_dataset = dataset["train"]
test_dataset = dataset["test"]
val_dataset = dataset["validation"]

# Reduce Training dataset by half
half = len(train_dataset) // 2
train_dataset = train_dataset.select(range(half))



In [7]:
from transformers import PreTrainedTokenizerFast

tokeniser = PreTrainedTokenizerFast.from_pretrained("./trained_tokeniser") 


data1 = train_dataset[1]
print(data1)
input_ids_list = data1["labels"]
print(tokeniser.decode(input_ids_list))


{'input_ids': tensor([    2,  1775,  1245, 13027,  1007,  2988,  1126,  1212,  3902,  1033,
         1024,  3823,  1055,  1024,  2893,  2995, 17793,  3085,  1212, 24621,
        28891,     3]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([ 1754,  1948,  3633,  1074,  7316,  1484,  1500,  3747,  1165, 17168,
         1644, 17319,  1033,  1156,  4890,  2851,  1079,     3])}
Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den n√§chsten Tagen.</s>


In [9]:
print(train_dataset)
print(test_dataset)
print(val_dataset)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1224808
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1969
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2062
})


In [15]:
# model definition:
import torch.nn as nn

# Note torch bilstm's process all timesteps of a sequence at once.

class Bilstm_Encoder(nn.Module):
    def __init__(self,vocab_size,embedding_dim,hidden_size,embedding_matrix):
        super().__init__()
        # input (B,L)
        self.embedding_matrix = embedding_matrix # (B,L,emb)
        self.bilstm = nn.LSTM(input_size=embedding_dim,num_layers=2, hidden_size= hidden_size, bidirectional=True, batch_first=True) # takes in (B,L,emb)
        self.bridge_h = nn.Linear(hidden_size*2,hidden_size)
        self.bridge_c = nn.Linear(hidden_size*2,hidden_size)

    def forward(self,input): # (B,L)
        embedded = self.embedding_matrix(input) # Returns (B,L,embedding_dim) 
        output, (h_n,c_n) = self.bilstm(embedded)
        # output: (B,L,2*hidden_size) outputs h_t fwd backward concatenated for the top layer, h_n = (D*num_layers,B,hidden_size) outputs h_n fwd, backward for every layer
        # Layer 1: concatenate fwd and backward h_0,h_n,c_0,c_n
        h_l1 = torch.cat([h_n[0], h_n[1]], dim=1) # (B, enc_hid*2)
        c_l1 = torch.cat([c_n[0], c_n[1]], dim=1) # (B, enc_hid*2)
        # Layer 2:
        h_l2 = torch.cat([h_n[2], h_n[3]], dim=1) # (B, enc_hid*2)
        c_l2 = torch.cat([c_n[2], c_n[3]], dim=1) # (B, enc_hid*2)

        decoder_h_0 = torch.tanh(torch.stack([self.bridge_h(h_l1), self.bridge_h(h_l2)], dim=0)) # (2,B,hidden_size)
        decoder_c_0 = torch.tanh(torch.stack([self.bridge_c(c_l1), self.bridge_c(c_l2)], dim=0)) # (2,B,hidden_size)
        return output, decoder_h_0, decoder_c_0

# compute f(hi,sj) for all hi, then softmax over.
class Luong_attention(nn.Module):
    def __init__(self,encoder_dim,decoder_dim): # Output = C_i = (B,2*hidden_size) 
        super().__init__()
        self.W = nn.Parameter(torch.FloatTensor(
            decoder_dim, encoder_dim).uniform_(-0.1, 0.1)) # (decoder,encoder)
    
     # query @ W @ values^T 
    def forward(self,query,values,attention_mask): # query:(B,decoder),values: (B,L,encoder_dim), attention_mask: (B,L)
        transformed_query = query @ self.W #  (B,dec)@(dec,enc) = (B,enc)
        attention_weights = transformed_query.unsqueeze(1) @ values.transpose(1,2)  # (B,1,enc)@(B,encoder_dim,L) = (B,1,L)

        attention_mask = attention_mask.unsqueeze(1) # (B,1,L)
        attention_weights = attention_weights.masked_fill(attention_mask == 0, float('-inf')) # remove any terms with paddings.

        attention_scores = torch.softmax(attention_weights,dim=-1) # (B,1,L)

        context_vector = (attention_scores @ values).squeeze(1) # (B,encoder_dim)
        return context_vector

class lstm_Decoder(nn.Module):
    def __init__(self,embedding_size,encoder_hidden_size,hidden_size,vocab,embedding_matrix):
        super().__init__()
        self.embedding = embedding_matrix # nn.Embedding(vocab_size,embedding_dim)
        self.attention = Luong_attention(encoder_hidden_size*2,hidden_size)
        self.lstm = nn.LSTM(input_size=embedding_size, num_layers=2,hidden_size= hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size + encoder_hidden_size*2,vocab)
    
    def forward(self,decoder_h_0,decoder_c_0,Encoder_Output,max_length,attention_mask,target_tensor=None):
        device = decoder_h_0.device
        batch_size = Encoder_Output.shape[0]
        eos_token_id = 3
        sos_token_id = 2

        finished = torch.zeros(batch_size, dtype=torch.bool, device=device).unsqueeze(1) # (B,1)
        input = torch.full((batch_size, 1), sos_token_id,device=device) # (B,1)
        starting_input = self.embedding(input) # (B,1,embedding_size)
        curr_input = starting_input # (B,1,embedding_size)
        curr_h = decoder_h_0
        curr_c = decoder_c_0
        logits = []

        if target_tensor is not None: # training mode          
            # Target_tensor (B,max_padding_len)
            max_length = target_tensor.shape[1] # Change max_len to max in label batch

        for i in range(max_length):
            # encoder_output: (B,L,2*encoder_dim)
            prediction, output, (h_n,c_n) = self.forward_step(curr_input,curr_h,curr_c,Encoder_Output,attention_mask)
            logits.append(prediction) # prediction (B,vocab)
            curr_h = h_n
            curr_c = c_n

            values,indices = torch.topk(prediction,1,dim=1) # (B,1)
            finished |= (indices == eos_token_id) # (B,1) keeps track of which sequences are finished

            if target_tensor is None: # inference
                curr_input = self.embedding(indices) # (B,1,embedding)
                if finished.all():
                    break
                
            else: # teacher forcing training 
                # target_tensor (B,pad_length)
                curr_input = self.embedding(target_tensor[:,i].unsqueeze(1)) # (B,1,embedding)
        
        return torch.stack(logits,dim=1) # (B,L,vocab)

    def forward_step(self,input,h_n,c_n,Encoder_Output,attention_mask):
        # input: (B,1,input_size)

        output,(h_n,c_n) = self.lstm(input,(h_n,c_n)) # output (B,1,dec_dim)
        query = output.squeeze(1) # (B,decode_dim)
        context_n = self.attention(query,Encoder_Output,attention_mask) # (B,enc*2)

        h_i_context_n = torch.cat([query,context_n],dim=1) # (B,enc*3)
        prediction = self.output(h_i_context_n) # (B,vocab)

        return prediction,output,(h_n,c_n)
        

class seq2seq_bilstm(nn.Module):
    def __init__(self,vocab_size,embedding_dim,encoder_hidden,decoder_hidden): # Output = C_i = (B,2*hidden_size) 
        super().__init__()

        self.embedding_matrix = nn.Embedding(vocab_size,embedding_dim) # (B,L,emb) shared embedding matrix
        self.encoder = Bilstm_Encoder(vocab_size,embedding_dim,encoder_hidden,self.embedding_matrix)
        self.decoder = lstm_Decoder(embedding_dim,encoder_hidden,decoder_hidden,vocab_size,self.embedding_matrix)

    def forward(self,input,max_length,attention_mask,target_tensor=None):
        output, decoder_h_0, decoder_c_0 = self.encoder(input)
        Logits = self.decoder(decoder_h_0,decoder_c_0,output,max_length,attention_mask,target_tensor) # (B,L,Vocab)
        
        return Logits


In [16]:
# Model Parameters:
max_length = 128
hidden_size = 128
vocab_size = 32000
embedding_dim = 128
batch_size = 32
lr = 1e-3
val_rate = 1000
num_epochs = 20

# special Tokens
unk = 0
pad = 1
sos = 2
eos = 3

In [None]:
def batch_process(batch,model,loss_fn,teacher_force =True):

    target_input = None
    input_ids = batch["input_ids"].to(device) # (B,L)
    attention_mask = batch["attention_mask"].to(device) # (B,L)
    labels = batch["labels"].to(device) # (B,label_L)

    if teacher_force == True:    
        # replace labels -100 with 1
        target_input = labels.clone()
        target_input[target_input == -100] = pad

    Logits = model(input_ids,max_length,attention_mask,target_input)
    B, L, V = Logits.shape

    print(Logits.shape)
    print(labels.shape)
    loss = loss_fn(
        Logits.view(B*L, V),
        labels.view(B*L)
    )
    return loss

In [None]:
# The training loop
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

# prepare collate_fn
collate_fn = DataCollatorForSeq2Seq(tokeniser,padding=True)

#prepare the dataloaders:
train_dataloader = DataLoader(dataset["train"], shuffle=True, batch_size=batch_size,collate_fn=collate_fn)
val_dataloader = DataLoader(dataset["validation"], shuffle=True, batch_size=batch_size,collate_fn=collate_fn)

model = seq2seq_bilstm(vocab_size,embedding_dim,hidden_size,hidden_size)
model.to(device)
model.zero_grad()
  
all_losses = []
all_val_losses = []
model.train()

optimiser = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in tqdm(range(num_epochs)):
    batch_counter = 0
    for batch in train_dataloader:
        batch_counter += 1
        loss = batch_process(batch,model,loss_fn,teacher_force=True)
        
        # all_losses.append(loss)

        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

        if batch_counter % val_rate == 0: # Every 1000 batches
            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for batch in val_dataloader:                            
                    total_val_loss += batch_process(batch,model,loss_fn,teacher_force=True) # Teacher forcing for val at the end apply BLEU

            all_val_losses.append(total_val_loss/len(val_dataloader))
            model.train()


#Labels shape is different since its matched with the one to one of the decoders output.



  0%|          | 0/20 [00:00<?, ?it/s]

torch.Size([32, 128, 32000])
torch.Size([32, 30])





RuntimeError: shape '[4096]' is invalid for input of size 960

In [None]:
print(all_losses)

[tensor(10.3774, device='cuda:0', grad_fn=<NllLossBackward0>)]
