In [28]:
import random
import string

# Seed Python's `random` module so the generated corpus and the mask positions are repeatable.
# Only `random` is seeded here; nothing in this cell seeds torch.
random.seed(52)

# Possible run lengths for each letter (e.g., 'aaa', 'bbbb', 'ccccccc')
repeats_range = [3, 4, 5, 7]

# Function to generate `text8` with a specified length and then apply random masks
def generate_text8_with_and_without_masks(target_length, mask_token="?", repeat_choices=None, mask_ratio=0.2):
    """Build a synthetic space-separated letter corpus and a randomly masked copy.

    The corpus cycles through the alphabet; each letter is repeated a random
    number of times drawn from ``repeat_choices`` ("a a a b b b b c c c ...").
    The text is cut to exactly ``target_length`` characters, then a fraction of
    its tokens is replaced by ``mask_token``.

    Args:
        target_length: Number of characters of the unmasked string. Values
            <= 0 produce two empty strings.
        mask_token: Token written in place of each masked letter.
        repeat_choices: Candidate run lengths (positive ints). Defaults to the
            notebook-level ``repeats_range`` when omitted.
        mask_ratio: Fraction of tokens to mask (``int(len(tokens) * mask_ratio)``
            tokens are replaced).

    Returns:
        ``(unmasked_text8_str, masked_text8_str)``. The unmasked string is the
        raw character slice, so it may end with a trailing space; the masked
        string is re-joined from tokens and never does.

    Raises:
        ValueError: If a run length is smaller than 1, or if ``mask_ratio``
            asks for more masks than there are tokens (from ``random.sample``).
    """
    if repeat_choices is None:
        # Keep the original behaviour: fall back to the notebook-level setting.
        repeat_choices = repeats_range
    if any(count < 1 for count in repeat_choices):
        raise ValueError("repeat_choices must contain only positive run lengths")

    letters = []     # one entry per token, e.g. ['a', 'a', 'a', 'b', ...]
    joined_len = 0   # len(' '.join(letters)), tracked without re-joining

    # Generate the letter sequence without any masks
    while joined_len < target_length:
        for char in string.ascii_lowercase:
            repeat_count = random.choice(repeat_choices)
            letters.extend(char * repeat_count)
            # n single-character tokens joined by spaces take 2n - 1 characters
            joined_len = 2 * len(letters) - 1

            if joined_len >= target_length:
                break

    # Join once and trim to the exact target length
    unmasked_text8_str = ' '.join(letters)[:target_length]

    # Split into tokens and pick the positions to mask
    tokens = unmasked_text8_str.split()
    num_masks = int(len(tokens) * mask_ratio)
    mask_indices = random.sample(range(len(tokens)), num_masks)

    # Mask a copy so the unmasked tokens stay untouched
    masked_tokens = tokens[:]
    for idx in mask_indices:
        masked_tokens[idx] = mask_token

    masked_text8_str = ' '.join(masked_tokens)

    return unmasked_text8_str, masked_text8_str

# Build the synthetic corpus: the unmasked string is trimmed to 50000 characters,
# and the masked string is the same token sequence with some tokens replaced by '?'
unmasked_text8, masked_text8 = generate_text8_with_and_without_masks(target_length=50000)

# Print both versions in full (the complete strings are written to the cell output)
print("Unmasked text8:\n", unmasked_text8)
print("\nMasked text8:\n", masked_text8)


Unmasked text8:
 a a a a a b b b c c c c c c c d d d d d e e e e e e e f f f g g g g h h h h i i i i i i i j j j j k k k k k l l l l l l l m m m n n n o o o o o o o p p p p p p p q q q r r r r s s s t t t t u u u u v v v v v v v w w w w x x x x x y y y y y y y z z z z z z z a a a a a a a b b b c c c d d d d d d d e e e f f f f f f f g g g g g h h h h i i i i j j j j j j j k k k k l l l l l m m m n n n n n o o o o p p p p p p p q q q q q q q r r r s s s s t t t u u u u v v v v v v v w w w w w w w x x x x y y y y z z z a a a a b b b c c c c d d d d e e e e e f f f f g g g g g g g h h h i i i i j j j j k k k k l l l l m m m m m m m n n n n o o o o p p p p p p p q q q q r r r s s s s t t t u u u v v v v v w w w w w x x x y y y y y y y z z z a a a a b b b b c c c d d d d d d d e e e f f f f g g g g g h h h h h h h i i i j j j j k k k k l l l l m m m m n n n o o o o p p p p p q q q q r r r r r s s s s s t t t t u u u u v v v w w w w w x x x y y y z z z z a a a a b b b b b c c c c c d d d d d

In [29]:
import tqdm
import collections
import more_itertools
import wandb
import pandas as pd
import torch
import random
import string


def preprocess(text: str) -> list[str]:
  """Lower-case `text`, turn punctuation into named tokens, and split on whitespace.

  '?' is the mask placeholder of the synthetic corpus and becomes '<mask>';
  consequently a literal question mark cannot be told apart from a mask.

  Args:
    text: Raw input string.

  Returns:
    The list of whitespace-separated tokens, in order of appearance.
  """
  # Applied in this order; replacements are padded with spaces so that each
  # symbol ends up as a token of its own after the split.
  replacements = (
    ('.',  ' <PERIOD> '),
    (',',  ' <COMMA> '),
    ('"',  ' <QUOTATION_MARK> '),
    (';',  ' <SEMICOLON> '),
    ('!',  ' <EXCLAMATION_MARK> '),
    ('?',  ' <mask> '),
    ('(',  ' <LEFT_PAREN> '),
    (')',  ' <RIGHT_PAREN> '),
    ('--', ' <HYPHENS> '),
    (':',  ' <COLON> '),
  )
  # Lower-casing happens first so the upper-case token names survive intact.
  text = text.lower()
  for symbol, token in replacements:
    text = text.replace(symbol, token)
  return text.split()

In [30]:
# `unmasked_text8` is a single string at this point, so ' '.join(...) iterates over
# its characters and inserts a space between each of them (existing spaces included).
# After whitespace splitting, this yields the same letter tokens once more.
titles_string = ' '.join(unmasked_text8)

# Appending it means the corpus now contains every token twice.
# NOTE(review): the "titles" naming looks like a leftover from another dataset; confirm
# the duplication is intended. Re-running this cell grows the corpus again.
unmasked_text8 += ' ' + titles_string

In [31]:
# `masked_text8` is a single string at this point, so ' '.join(...) iterates over
# its characters and inserts a space between each of them (existing spaces included).
# After whitespace splitting, this yields the same tokens ('?' masks included) once more.
titles_string = ' '.join(masked_text8)

# Appending it means the masked corpus now contains every token twice, with identical
# mask positions in both halves.
# NOTE(review): confirm the duplication is intended. Re-running this cell grows the corpus again.
masked_text8 += ' ' + titles_string

In [32]:
# Tokenise the masked text; preprocess() turns each '?' placeholder into a '<mask>' token
corpus: list[str] = preprocess(masked_text8)


In [33]:
# Dump the whole masked token list (the full list is written to the cell output)
print(corpus)

['a', 'a', 'a', '<mask>', 'a', '<mask>', 'b', 'b', 'c', 'c', 'c', 'c', 'c', '<mask>', 'c', '<mask>', 'd', 'd', '<mask>', 'd', 'e', 'e', 'e', 'e', 'e', 'e', 'e', '<mask>', 'f', 'f', '<mask>', 'g', '<mask>', 'g', 'h', 'h', '<mask>', 'h', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'j', 'j', 'j', 'j', 'k', 'k', 'k', 'k', '<mask>', 'l', 'l', 'l', 'l', 'l', '<mask>', 'l', 'm', 'm', 'm', 'n', 'n', '<mask>', 'o', 'o', 'o', '<mask>', 'o', 'o', '<mask>', 'p', 'p', '<mask>', 'p', 'p', 'p', 'p', 'q', 'q', 'q', 'r', 'r', 'r', 'r', '<mask>', '<mask>', 's', 't', '<mask>', 't', 't', 'u', 'u', 'u', 'u', 'v', 'v', '<mask>', 'v', 'v', 'v', 'v', 'w', 'w', 'w', 'w', '<mask>', 'x', 'x', 'x', '<mask>', '<mask>', 'y', 'y', 'y', 'y', 'y', 'y', 'z', 'z', 'z', '<mask>', 'z', 'z', 'z', 'a', '<mask>', 'a', 'a', 'a', '<mask>', 'a', 'b', '<mask>', '<mask>', 'c', 'c', 'c', 'd', '<mask>', 'd', 'd', 'd', 'd', 'd', 'e', 'e', 'e', 'f', '<mask>', 'f', 'f', '<mask>', 'f', '<mask>', 'g', 'g', 'g', 'g', 'g', 'h', 'h', 'h', '<mask>',

In [96]:
# Display the raw masked string (letters separated by spaces, '?' marks a masked token)
masked_text8

'a a a ? a ? b b c c c c c ? c ? d d ? d e e e e e e e ? f f ? g ? g h h ? h i i i i i i i j j j j k k k k ? l l l l l ? l m m m n n ? o o o ? o o ? p p ? p p p p q q q r r r r ? ? s t ? t t u u u u v v ? v v v v w w w w ? x x x ? ? y y y y y y z z z ? z z z a ? a a a ? a b ? ? c c c d ? d d d d d e e e f ? f f ? f ? g g g g g h h h ? ? ? i i j ? j j ? j j k k ? k l l l l ? m m m n n ? n n o o o o p p p p p ? p q q q ? q q q r r ? s s s s t t t u u ? u ? v ? v v v v w w ? w w w w x x x x y y y y z ? ? a a a a b ? b ? c c c d d d d ? e e e e f f f f g g g g g ? g h h h i i i i j j j ? k k k k l ? ? ? m m m m m m m n n n ? ? o o o ? p p p p ? p q q q q r r r ? ? s s t t t u u u v v ? v ? w w w w ? x ? x y y y y y ? y z ? z a a ? a ? b b b c c c d d d d ? ? ? e e e ? f f f g g ? ? g h ? h ? h ? h ? i i j j j j k k k k l l l l m m m ? n n n o ? o o p p p p p q q q q r r r r r s s ? s s ? t t t u ? u u ? v v w w w w ? x x x y ? ? z z z ? a ? a ? b b ? ? b ? c ? c c d d d ? d e e e e e ? f f

In [34]:
def create_lookup_tables(words: list[str]) -> tuple[dict[str, int], dict[int, str]]:
  """Build word<->id lookup tables ordered by descending word frequency.

  The most frequent word gets id 0. Words with equal counts keep the order in
  which they first appear in `words`.

  Args:
    words: Token list to build the vocabulary from.

  Returns:
    `(vocab_to_int, int_to_vocab)`: word -> id and id -> word dictionaries.
  """
  frequencies = collections.Counter(words)
  vocab_to_int: dict[str, int] = {}
  int_to_vocab: dict[int, str] = {}
  # most_common() sorts by count, highest first, and is stable for ties
  for index, (word, _count) in enumerate(frequencies.most_common()):
    int_to_vocab[index] = word
    vocab_to_int[word] = index
  return vocab_to_int, int_to_vocab

In [35]:
# Build the word -> id and id -> word dictionaries from the masked corpus
words_to_ids, ids_to_words = create_lookup_tables(corpus)

In [59]:
# Convert every token of the masked corpus to its integer id
tokens = [words_to_ids[word] for word in corpus]

In [40]:
# Tokenise the clean text with the same preprocessing as the masked corpus.
# NOTE(review): this rebinds `unmasked_text8` from a string to a list of tokens, so
# running the cell a second time passes a list to preprocess(), which calls text.lower().
unmasked_text8: list[str] = preprocess(unmasked_text8)

In [58]:
# Map the clean tokens to ids with the vocabulary that was built from the masked corpus.
# A word that is missing from `words_to_ids` would raise KeyError here.
tokens_unmasked = [words_to_ids[word] for word in unmasked_text8]

In [61]:
# First ten ids of the masked token stream
tokens[:10]

[10, 10, 10, 0, 10, 0, 23, 23, 14, 14]

In [62]:
# First ten ids of the unmasked (ground-truth) token stream
tokens_unmasked[:10]

[10, 10, 10, 10, 10, 23, 23, 23, 14, 14]

In [None]:
import torch

class SkipGramFoo(torch.nn.Module):
    """Single-head self-attention over a token window with a random 0/1 mask
    applied to the attention weights, followed by a projection to vocabulary logits.

    Args:
        voc: Vocabulary size (rows of the embedding, width of the output logits).
        emb: Embedding / attention dimension.
        ctx: Accepted for call compatibility; not used by the model.
    """

    def __init__(self, voc, emb, ctx):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
        self.ffw = torch.nn.Linear(in_features=emb, out_features=voc, bias=False)
        # Projections sized from `emb` so any embedding width works (not only 64)
        self.linear_q = torch.nn.Linear(emb, emb)
        self.linear_k = torch.nn.Linear(emb, emb)
        self.linear_v = torch.nn.Linear(emb, emb)
        # Learnable bias added to the attention output
        self.attn_embedding_bias = torch.nn.Parameter(torch.zeros(emb))

    def forward(self, inpt):
        """Return vocabulary logits for every position of `inpt`.

        Args:
            inpt: LongTensor of token ids, shape [seq_len] or [1, seq_len].

        Returns:
            Tensor of shape [seq_len, voc] for the shapes above.
        """
        emb = self.emb(inpt)
        if emb.dim() == 3:
            # Drop a leading singleton batch dimension (no-op when it is larger than 1)
            emb = emb.squeeze(0)

        # Query, key and value are all projected from the token embeddings
        query = self.linear_q(emb)
        key = self.linear_k(emb)
        value = self.linear_v(emb)

        # Scaled dot-product similarity, normalised by sqrt of the key dimension
        scaling_factor = key.size(-1) ** 0.5
        similarity_matrix = torch.matmul(query, key.transpose(-2, -1)) / scaling_factor
        soft_matrix = torch.nn.functional.softmax(similarity_matrix, dim=-1)

        # Random binary mask matching the attention matrix in shape, device and dtype.
        # NOTE(review): the mask is drawn on every call, evaluation included, and the
        # masked rows are not re-normalised; confirm this is the intended experiment.
        mask = torch.randint(0, 2, soft_matrix.shape, device=soft_matrix.device).to(soft_matrix.dtype)
        soft_matrix_masked = soft_matrix * mask

        attention = torch.matmul(soft_matrix_masked, value) + self.attn_embedding_bias
        out = self.ffw(attention)

        return out


In [None]:
import torch

class SkipGramFoo(torch.nn.Module):
    """Single-head scaled dot-product self-attention followed by a vocabulary projection.

    Args:
        voc: Vocabulary size.
        emb: Embedding / attention dimension.
        ctx: Accepted for call compatibility; not used by the model.
    """

    def __init__(self, voc, emb, ctx):
        super().__init__()
        # Token embedding table and the bias-free output projection to vocabulary logits
        self.emb = torch.nn.Embedding(voc, emb)
        self.ffw = torch.nn.Linear(emb, voc, bias=False)

        # Query / key / value projections
        self.linear_q = torch.nn.Linear(emb, emb)
        self.linear_k = torch.nn.Linear(emb, emb)
        self.linear_v = torch.nn.Linear(emb, emb)

        # Learnable bias added to the attention output
        self.attn_embedding_bias = torch.nn.Parameter(torch.zeros(emb))

    def forward(self, inpt):
        """Map token ids of shape [..., seq_len] to logits of shape [..., seq_len, voc]."""
        hidden = self.emb(inpt)

        q = self.linear_q(hidden)
        k = self.linear_k(hidden)
        v = self.linear_v(hidden)

        # Pairwise similarities, divided by sqrt of the key dimension
        scores = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)

        # Each row becomes a distribution over the positions attended to
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values plus the bias (broadcast over positions)
        context = weights @ v + self.attn_embedding_bias

        return self.ffw(context)


In [None]:
import torch

class SkipGramFoo(torch.nn.Module):
    """Multi-head self-attention over a token sequence, then LayerNorm and a
    projection to vocabulary logits.

    Args:
        voc: Vocabulary size.
        emb: Embedding dimension; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
    """

    def __init__(self, voc, emb, num_heads=4):
        super().__init__()
        # Validate before deriving the per-head size
        assert emb % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = emb // num_heads  # Dimension per head

        self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
        self.ffw = torch.nn.Linear(in_features=emb, out_features=voc, bias=False)

        # Linear layers for multi-head query, key, and value
        self.linear_q = torch.nn.Linear(emb, emb)
        self.linear_k = torch.nn.Linear(emb, emb)
        self.linear_v = torch.nn.Linear(emb, emb)

        self.layer_norm = torch.nn.LayerNorm(emb)

        # Learnable bias added to the attention output before normalisation
        self.attn_embedding_bias = torch.nn.Parameter(torch.zeros(emb))

    def forward(self, inpt):
        """Return vocabulary logits for every position.

        Args:
            inpt: LongTensor of token ids, shape [seq_len] or [..., seq_len].
                Attention is computed along the last axis; leading axes are
                treated as independent batch dimensions.

        Returns:
            Tensor of shape [*inpt.shape, voc].
        """
        emb = self.emb(inpt)          # [..., seq_len, emb_dim]
        lead_shape = emb.shape[:-1]   # (..., seq_len)

        def split_heads(projected):
            # [..., seq_len, emb_dim] -> [..., num_heads, seq_len, head_dim]
            return projected.view(*lead_shape, self.num_heads, self.head_dim).transpose(-3, -2)

        query = split_heads(self.linear_q(emb))
        key = split_heads(self.linear_k(emb))
        value = split_heads(self.linear_v(emb))

        # Scaled dot-product attention, computed independently per head
        scaling_factor = self.head_dim ** 0.5
        similarity_matrix = torch.matmul(query, key.transpose(-2, -1)) / scaling_factor
        soft_matrix = torch.nn.functional.softmax(similarity_matrix, dim=-1)

        # Apply the weights, then merge heads back to [..., seq_len, emb_dim]
        attention = torch.matmul(soft_matrix, value).transpose(-3, -2).contiguous()
        attention = attention.view(*lead_shape, self.num_heads * self.head_dim)

        # Add learnable bias, normalise, and project to the vocabulary
        attention = attention + self.attn_embedding_bias
        attention = self.layer_norm(attention)

        out = self.ffw(attention)  # Shape: [..., seq_len, voc]

        return out


In [157]:
# Constructor arguments: vocabulary size, embedding size 64, and a third value (2) that is
# bound to `ctx` or `num_heads` depending on which SkipGramFoo definition was run last
args = (len(words_to_ids), 64,2)
mFoo = SkipGramFoo(*args)
# Report the total number of model parameters
print('mFoo', sum(p.numel() for p in mFoo.parameters()))
# Adam optimizer with learning rate 0.003
opFoo = torch.optim.Adam(mFoo.parameters(), lr=0.003)
# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mFoo 16000


In [159]:
import torch
import more_itertools
import tqdm
import wandb

# Initialize W&B
wandb.init(project="word2vec_attention", name='bias weighting, with more softmax - text8 synthetic data 1')

# Hyper-parameters
learning_rate = 0.001
num_epochs = 10
max_tokens = 50000   # only windows starting within the first `max_tokens` tokens are used
context_size = 5     # NOTE(review): not read anywhere in this cell
window_size = 5      # Total tokens in the window

mFoo = mFoo.to(device)

# Initialize the optimizer (replaces any optimizer created earlier for mFoo)
opFoo = torch.optim.Adam(mFoo.parameters(), lr=learning_rate)

# Cross-entropy over the vocabulary, averaged over the positions of a window
criterion = torch.nn.CrossEntropyLoss()

# Non-overlapping windows of `window_size` tokens. They are identical for every epoch,
# so they are built once up front.
wins = [tokens[i:i + window_size] for i in range(0, min(len(tokens), max_tokens), window_size)]
targets = [tokens_unmasked[i:i + window_size] for i in range(0, min(len(tokens_unmasked), max_tokens), window_size)]

# Pair each masked window with its unmasked target, drop incomplete windows
# (possible at the end of the data), and move the tensors to the device once.
batches = [
    (torch.LongTensor(win).to(device), torch.LongTensor(target).to(device))
    for win, target in zip(wins, targets)
    if len(win) >= window_size and len(target) >= window_size
]

for epoch in range(num_epochs):
    prgs = tqdm.tqdm(batches, total=len(batches), desc=f"Epoch {epoch + 1}", leave=False)

    total_loss = 0.0  # Sum of per-window losses for this epoch

    for inpt, true_index in prgs:
        # Zero gradients
        opFoo.zero_grad()

        # Forward pass: masked window in, one row of vocabulary logits per position out
        out = mFoo(inpt)

        # Loss against the unmasked ids of the same window (all positions, not only masked ones)
        loss = criterion(out, true_index)

        # Backward pass and optimization
        loss.backward()
        opFoo.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Log the loss for this window
        wandb.log({'loss': batch_loss, 'learning_rate': learning_rate})

    # Average over the windows that were actually trained on
    average_loss = total_loss / len(batches) if batches else 0
    wandb.log({'average_loss': average_loss})

# Save the model's state dict
torch.save(mFoo.state_dict(), 'model.pth')
# Finish the W&B logging
wandb.finish()


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
average_loss,▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▆▅▆▁▄▁▁▁▅▁▂▂█▄▃▁▅▁▃▂▁▂▁▄▁▁▂▃▅▂▁▃▁▂▂▇▂▂▂▁

0,1
average_loss,0.27832
learning_rate,0.001
loss,0.31183


                                                               

VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
average_loss,█▆▅▄▃▂▂▂▁▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▄▄▁▂▄█▄▁▂▁▁▄▅▁▄▃▁▁█▂▁▄▁▁▁▁▂▁▃▁▁▃▁▁▁▃▁▁▁▁

0,1
average_loss,0.17779
learning_rate,0.001
loss,0.11249


In [160]:
# Load the saved weights and put the model in eval mode

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = (len(words_to_ids), 64,2)

# Keep the model on the same device that the inputs are sent to below, and remap the
# checkpoint onto it so weights saved from a GPU run also load on a CPU-only machine.
model = SkipGramFoo(*args).to(device)
model.load_state_dict(torch.load('model.pth', map_location=device))
model.eval()


SkipGramFoo(
  (emb): Embedding(27, 64)
  (ffw): Linear(in_features=64, out_features=27, bias=False)
  (linear_q): Linear(in_features=64, out_features=64, bias=True)
  (linear_k): Linear(in_features=64, out_features=64, bias=True)
  (linear_v): Linear(in_features=64, out_features=64, bias=True)
)

"hello my name is Omar" -> "hello my name is Omar"
"Hello __ name is ____" -> "hello my name is Omar"
"Hi, my name is Omar, and this is my car, which i call ___'s car"

"is ___" -> "is Omar"

In [None]:
import datasets


from datasets import load_dataset

# NOTE(review): this cell is an unfinished sketch. `sentence_piece`, `masked`,
# `cross_entropy` and `optimiser` are not defined anywhere in the visible notebook.
ds = load_dataset("Salesforce/wikitext", "wikitext-2-v1")

split = 'test'

# Slicing a Dataset returns a dict of columns, so pick the "text" column before joining
# (joining the dict itself would only join its column names).
merged_text = " ".join(ds[split][:1000]["text"])

tkns_ds = sentence_piece.encode(merged_text)

# [4, 234, 23, 12321 ... , 12, 41]

# Sliding windows of 512 token ids
windows = more_itertools.windowed(tkns_ds, 512)

for window in windows:
    # Mask the window
    ...

    # Run masked window
    preds = model(masked)

    # Compare with unmasked window
    loss = cross_entropy(preds, window)

    # Backprop (gradients must be reset before each backward pass)
    ...
    loss.backward()
    optimiser.step()


# Test



In [None]:
# Spot-check the trained model on one masked window (tokens 5..9)

# Send the input to whichever device the model's weights live on
model_device = next(model.parameters()).device
test_input = torch.LongTensor(tokens[5:10]).to(model_device)

# Inference only: no autograd graph is needed
with torch.no_grad():
    pred = model(test_input)

# Most likely vocabulary id at every position
labels = torch.argmax(pred, dim=-1)
print('what the actual tokens should be',tokens_unmasked[5:10])
print('what was missing ',test_input)
print('model prediction',labels)
predicted_tokens = labels.tolist()
predicted_words = [ids_to_words[token_id] for token_id in predicted_tokens]
print('Model prediction (in words):', predicted_words)

what the actual tokens should be [23, 23, 23, 14, 14]
what was missing  tensor([ 0, 23, 23, 14, 14])
model prediction tensor([23, 23, 23, 14, 14])
Model prediction (in words): ['b', 'b', 'b', 'c', 'c']


In [3]:
pip install torchvision

Collecting torchvisionNote: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.2.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip



  Downloading torchvision-0.20.1-cp310-cp310-win_amd64.whl (1.6 MB)
     ---------------------------------------- 1.6/1.6 MB 1.8 MB/s eta 0:00:00
Collecting torch==2.5.1
  Downloading torch-2.5.1-cp310-cp310-win_amd64.whl (203.1 MB)
     ------------------------------------- 203.1/203.1 MB 11.1 MB/s eta 0:00:00
Collecting typing-extensions>=4.8.0
  Downloading typing_extensions-4.12.2-py3-none-any.whl (37 kB)
Collecting sympy==1.13.1
  Downloading sympy-1.13.1-py3-none-any.whl (6.2 MB)
     ---------------------------------------- 6.2/6.2 MB 23.3 MB/s eta 0:00:00
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ------------------------------------- 536.2/536.2 kB 32.9 MB/s eta 0:00:00
Installing collected packages: mpmath, typing-extensions, sympy, torch, torchvision
  Attempting uninstall: typing-extensions
    Found existing installation: typing_extensions 4.4.0
    Uninstalling typing_extensions-4.4.0:
      Successfully uninstalled typing_ext

In [None]:
    
import torch

class SkipGramFoo(torch.nn.Module):
    """Single-head dot-product self-attention (unscaled) followed by a vocabulary projection.

    The 1/sqrt(d) scaling of the scores and the random masking of the attention
    weights are both switched off in this variant.

    Args:
        voc: Vocabulary size.
        emb: Embedding / attention dimension.
        ctx: Accepted for call compatibility; not used by the model.
    """

    def __init__(self, voc, emb, ctx):
        super().__init__()
        # Token embeddings and the bias-free projection back to the vocabulary
        self.emb = torch.nn.Embedding(voc, emb)
        self.ffw = torch.nn.Linear(emb, voc, bias=False)

        # Projections producing queries, keys and values
        self.linear_q = torch.nn.Linear(emb, emb)
        self.linear_k = torch.nn.Linear(emb, emb)
        self.linear_v = torch.nn.Linear(emb, emb)

        # Learnable bias added to the attention output
        self.attn_embedding_bias = torch.nn.Parameter(torch.zeros(emb))

    def forward(self, inpt):
        """Map token ids of shape [..., seq_len] to logits of shape [..., seq_len, voc]."""
        embedded = self.emb(inpt)

        queries = self.linear_q(embedded)
        keys = self.linear_k(embedded)
        values = self.linear_v(embedded)

        # Raw dot-product similarity between every pair of positions (no scaling applied)
        scores = queries @ keys.transpose(-2, -1)

        # Softmax over the attended positions; the weights are used as-is (no masking)
        weights = torch.softmax(scores, dim=-1)

        # Mix the values, add the bias (broadcast over positions), project to logits
        mixed = weights @ values
        return self.ffw(mixed + self.attn_embedding_bias)




In [None]:
# # mFoo(mFoo.emb(torch.tensor(1)), mFoo.emb(torch.tensor(1)))
# v1 = mFoo.emb(torch.tensor(1))
# v2 = mFoo.emb(torch.tensor(0))

In [None]:
# out = mFoo(inpt, trgs)

In [None]:
# print (inpt)
# print (trgs)

In [None]:
# print (out.shape) # Vocabulary size - this output contains the probability distribution after the activation function.
# ''' This means we are missing the activation function in the model.'''
# print (max(out[0])) # This is the token selected 
# print (out)
# words_to_ids['anarchism']

In [None]:
# print(v1.shape)
# print(v2.shape)

In [None]:
# v22 = torch.unsqueeze(v2, 1)
# print(v22.T.shape)


In [None]:
# TODO: write a function that loads the saved model and tests it on the synthetic data