<h2> Image Embedding </h2>

In [1]:
!pip install datasets
!pip install huggingface_hub
!pip install tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import login
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from tqdm import tqdm
from PIL import Image
from transformers import AutoTokenizer, AutoModel
class PatchEmbedings(torch.nn.Module):
    def __init__(self, img_size = 224, patch_size = 16, hidden_size = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        #CONVOLOTUION FOR PATCH EXTRACTION
        self.conv = nn.Conv2d(in_channels = 3, out_channels = hidden_size, kernel_size = patch_size, stride = patch_size)

        nn.init.xavier_uniform_(self.conv.weight)
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):

        if x.size(2) != self.img_size or x.size(3) != self.img_size:
            raise ValueError(f"Input image size is different than model trained one {x.shape}. \n It must be {self.img_size} x {self.img_size}")
        x = self.conv(x)
        x = x.flatten(2) #This way I remain the batches and channels unchanged and since the H&W are now H/patc_size = num_patches
        x = x.transpose(1, 2) #NOW THE TENSOR IS (Num_batches, num_patches, hidden_size_channels)

        return x


Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

<h2>Multi-Head Self Attention mechanism</h2>

In [39]:
class HeadAttentionLayer(nn.Module):
    def __init__(self, dropout, is_decoder, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd,head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self,x):
      #FIRST OF ALL, I GET THE VALUES FROM THE INPUT SHAPE, THEN I INSTANTIATE THE KEY,QUERY,VALUE LAYERS

        batch_size,seq_length, num_channels = x.shape
        key = self.key(x)
        query = self.query(x)
        value = self.value(x)

      #THEN I CALCULATE DOT PRODUCT BETWEEN EACH QUERY AND KEYS
      #FINALLY I CHANGE THE SCALE OF THE OUPUT TO 1/SQRT(NUM_CHANNELS)
      #THIS LAST STEP IS DONE BECAUSE DOT PRODUCT VALUES CAN GET TO PRETTY HIGH VALUES, THEN SOFTMAX WILL HAVE VERY SMALL GRANDIENTS
      #WITH THIS WE AVOID VANISHING / EXPLODING GRADIENTS
        wei = query @ key.transpose(-2,-1) * (num_channels ** -0.5)

        if self.is_decoder:
            tril = torch.tril(torch.ones(seq_length, seq_length, dtype = torch.bool, device = x.device))
            wei = wei.masked_fill(tril == 0, float("-inf"))

        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ value

        return out

In [None]:
class MultiModalProjector(nn.Module):
    def __init__(self, n_embd, image_embed_dim, dropout = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_embed_dim, image_embed_dim * 4),
            nn.GeLU(),
            nn.Linear(image_embed_dim*4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self,x):
        x = self.net(x)
        return x

In [None]:
class VisionLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size,
                 n_layer, img_size, patch_size, num_heads,
                 num_blks, emb_dropout, blk_dropout):
        super().__init__()
        num_hiddens = image_embed_dim
        assert num_hiddens % num_heads == 0

        self.vision_encoder = PatchEmbedings(96,16,512)
        self.decoder = Head(0.1,True,512,)

<h2> Helper functions </h2>

In [3]:
def image_embedding(image,patch_size):
    print(f"Before unfold {image.shape}")
    patches = image.unfold(2,size = patch_size, step = patch_size).unfold(3,size = patch_size, step = patch_size)
    num_patches = ()
    num_patches_w = image.shape[2]//patch_size
    num_patches_h = image.shape[3]//patch_size
    num_patches = num_patches_h * num_patches_w
    print(f"After unfold {patches.shape}")
    #TODO .CONTIGUOUS IS NECCESARRY FOR .VIEW
    patches = patches.permute(0,2,3,1,4,5).contiguous()
    print(patches.shape)
    patches = patches.view(image.shape[0], num_patches, -1)
    print(f"Patches shape: {patches.shape}")


<h2> Main code </h2>

In [37]:
class Dataset(torch.utils.data.Dataset):
  def __init__(self,features,labels,patch_embedding):
    #TRAINING EXAMPLES
    self.features = features
    self.labels = labels
    self.patch_embedding = patch_embedding

    #TRANSFORMS
    self.train_transforms = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(p = 0.4),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.5,0.5,0.5], std = [0.5, 0.5 ,0.5])
    ])

    #TOKENZATION
    self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    self.embedding = nn.Embedding(self.tokenizer.vocab_size,768)


  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):

    features_iter = self.features[idx]
    labels_iter = self.labels[idx]

    label_token = self.tokenizer(labels_iter,padding = "max_length", max_length=14, truncation = True, return_tensors = "pt")
    feature_transform = self.train_transforms(features_iter)

    #I WILL ADD ONE COORDINATE TO USE AS BATCH_SIZE SO I CAN PASS THIS INTO PATCH_EMBEDDING
    feature_transform = feature_transform.view(1, feature_transform.shape[0],feature_transform.shape[1],feature_transform.shape[2])
    feature_transform = self.patch_embedding(feature_transform)
    feature_transform = feature_transform.squeeze(0)


    label_tokens = self.embedding(label_token['input_ids'])
    label_tokens_final = label_tokens.unsqueeze(0)

    attention_mask = label_token['attention_mask']
    attention_mask = self.embedding(attention_mask).unsqueeze(0)

    return feature_transform,label_tokens_final, attention_mask



def model_training(dataloader, patch_embedding, head_attention):

  for (x,y,z) in tqdm(dataloader, desc = "TRAINING"):

    x = x.to("cpu")
    y = y.to("cpu")
    z = z.to("cpu")

    print(f"Head attention input {x.shape}")
    output = head_attention(x)
    print(f"Output of head attention layers {output}")
    #prediction = head(x)
    #print(f"Prediction shape {prediction.shape}")
  return x,y,z

def tokenizer_trials(tokens):
  max_length = 0
  for dict in tokens:
    for sentences in dict['tokens']:
      if len(sentences) > max_length:
        max_length = len(sentences)

  print(max_length)








In [6]:
from google.colab import drive
import os

login(token="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")
dataset = load_dataset("xcpan/coco2017", split ="train")
print(dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/329 [00:00<?, ?B/s]

train-00000-of-00011.parquet:   0%|          | 0.00/477M [00:00<?, ?B/s]

train-00001-of-00011.parquet:   0%|          | 0.00/469M [00:00<?, ?B/s]

train-00002-of-00011.parquet:   0%|          | 0.00/481M [00:00<?, ?B/s]

train-00003-of-00011.parquet:   0%|          | 0.00/480M [00:00<?, ?B/s]

train-00004-of-00011.parquet:   0%|          | 0.00/482M [00:00<?, ?B/s]

train-00005-of-00011.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

train-00006-of-00011.parquet:   0%|          | 0.00/491M [00:00<?, ?B/s]

train-00007-of-00011.parquet:   0%|          | 0.00/478M [00:00<?, ?B/s]

train-00008-of-00011.parquet:   0%|          | 0.00/479M [00:00<?, ?B/s]

train-00009-of-00011.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

train-00010-of-00011.parquet:   0%|          | 0.00/475M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/102512 [00:00<?, ? examples/s]

Dataset({
    features: ['caption', 'image'],
    num_rows: 102512
})


In [40]:
from datasets import load_from_disk
from PIL import Image


small_dataset = dataset.shuffle(seed=42).select(range(500))
patch = PatchEmbedings()
head_attention = HeadAttentionLayer(0.1, True, 768, 5)
#TODO USE THIS TO TRAIN WITH A LOWER VERSION OF THE DATASET
  #dataset = load_dataset("HuggingFaceM4/COCO", split = "train")

images = small_dataset['image']
caption = small_dataset['caption']
#images, conversations = data_checking(images, conversations)

#new_tokens = [token['tokens'] for token in tokens]
dataset_class = Dataset(images, caption, patch)
training_loader = DataLoader(dataset_class, batch_size = 32, shuffle=True)

feature_transform, final_tokens, attention_mask = model_training(training_loader, patch, head_attention)
print(f"Post patch embedding checking \n Feature shape {feature_transform.shape} \n Final tokens shape {final_tokens.shape} \n Attention mask shape {attention_mask.shape}")

TRAINING:   6%|▋         | 1/16 [00:00<00:06,  2.39it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 1.0141e-02, -5.3605e-02,  4.7145e-03,  1.0700e-01,  5.1938e-03],
         [ 8.6504e-03, -6.2088e-02,  9.5750e-03,  1.1886e-01, -1.4345e-03],
         [ 8.9725e-03, -6.6526e-02,  1.2002e-02,  1.2510e-01, -5.6389e-03],
         ...,
         [-4.3021e-03,  1.7306e-02, -1.3468e-03, -3.1940e-02,  3.2085e-03],
         [-5.4102e-03,  2.0063e-02, -1.5259e-03, -3.8690e-02,  3.4209e-03],
         [-5.6641e-03,  1.7490e-02, -9.7320e-04, -3.4859e-02,  2.5050e-03]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.0302e-02, -6.6777e-02,  8.4923e-03,  1.1627e-01, -1.0691e-02],
         [ 1.0892e-02, -5.9871e-02,  1.4050e-02,  1.1137e-01, -1.4218e-02],
         ...,
         [ 4.1756e-03, -1.2960e-02,  1.0363e-02,  2.1332e-02,  1.3017e-03],
         [ 4.3853e-03, -1.2185e-02,  1.1032e-02,  1.8403e-02,  1.4457e-03],
         [ 4.5599e-03, -1.2678e-02,  1.1204e-02,  2.00

TRAINING:  12%|█▎        | 2/16 [00:00<00:05,  2.51it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 0.0267, -0.0668,  0.0003,  0.1666, -0.0066],
         [ 0.0211, -0.0569,  0.0015,  0.1503, -0.0020],
         [ 0.0141, -0.0379,  0.0010,  0.1002, -0.0013],
         ...,
         [ 0.0119, -0.0219,  0.0149,  0.0289,  0.0043],
         [ 0.0119, -0.0224,  0.0152,  0.0300,  0.0039],
         [ 0.0119, -0.0222,  0.0149,  0.0296,  0.0035]],

        [[ 0.0498,  0.0410,  0.0033,  0.0773,  0.0243],
         [ 0.0196,  0.0326, -0.0105,  0.0933,  0.0055],
         [ 0.0083,  0.0218, -0.0124,  0.0727,  0.0009],
         ...,
         [ 0.0061, -0.0301,  0.0103,  0.0567, -0.0058],
         [ 0.0059, -0.0308,  0.0102,  0.0565, -0.0057],
         [ 0.0067, -0.0326,  0.0109,  0.0600, -0.0066]],

        [[ 0.0172, -0.0728,  0.0035,  0.1603, -0.0098],
         [ 0.0192, -0.0540,  0.0037,  0.1449, -0.0031],
         [ 0.0217, -0.0493,  0.0083,  0.1463, -0.0024],
         ...,
         [ 0.0037, -0.0124,  0.005

TRAINING:  19%|█▉        | 3/16 [00:01<00:05,  2.58it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-1.3843e-03, -7.1905e-02,  8.7948e-03,  1.0249e-01,  1.5809e-03],
         [ 7.7594e-03, -6.4377e-02,  5.5009e-03,  7.2184e-02, -7.1092e-04],
         [ 1.0948e-02, -5.1311e-02,  3.6392e-03,  4.9205e-02,  4.7381e-03],
         ...,
         [ 1.2730e-02, -3.9283e-02,  7.1923e-03,  6.2963e-02, -2.6327e-03],
         [ 1.3007e-02, -4.0049e-02,  7.0700e-03,  6.2954e-02, -2.4589e-03],
         [ 1.3615e-02, -4.1246e-02,  7.7410e-03,  6.5515e-02, -2.8838e-03]],

        [[-2.0562e-02, -4.6382e-02, -2.0824e-03,  7.3871e-02,  6.7506e-03],
         [-1.3625e-02, -1.3721e-02, -3.3543e-03,  1.8086e-02,  5.4000e-03],
         [-9.0831e-03, -9.1473e-03, -2.2362e-03,  1.2057e-02,  3.6000e-03],
         ...,
         [ 4.5652e-04,  7.2290e-04,  5.7234e-03, -1.2522e-02,  4.0293e-03],
         [ 8.3435e-07,  1.6116e-03,  4.6774e-03, -1.5782e-02,  4.1718e-03],
         [ 7.6754e-05,  3.0737e-05,  5.3116e-03, -1.33

TRAINING:  25%|██▌       | 4/16 [00:01<00:04,  2.60it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-1.8342e-02, -4.5482e-02,  2.3348e-02,  1.0591e-01, -1.6346e-02],
         [-6.6406e-03, -1.7904e-02, -5.1540e-03,  3.6204e-02,  2.2424e-03],
         [-3.7663e-02, -3.1334e-02, -1.4271e-02,  3.4750e-02,  5.1034e-03],
         ...,
         [-4.2274e-04, -5.1450e-03,  3.6892e-03,  6.0819e-03, -2.3961e-04],
         [-4.6587e-04, -4.7239e-03,  2.8496e-03,  5.9484e-03, -2.6506e-04],
         [-6.1590e-04, -4.4509e-03,  3.9662e-03,  4.4015e-03, -9.0004e-04]],

        [[ 2.5864e-02, -6.8517e-02, -1.2230e-02,  1.5986e-01, -1.0719e-03],
         [ 2.7836e-02, -3.8094e-02,  1.8701e-03,  1.4047e-01, -3.2552e-03],
         [ 3.1231e-03, -1.0322e-02,  1.1513e-02,  3.8326e-02, -1.0318e-02],
         ...,
         [ 2.6110e-03, -1.8216e-02,  6.6519e-03,  4.3268e-02, -8.9259e-03],
         [ 2.4567e-03, -1.7847e-02,  6.3697e-03,  4.2294e-02, -8.8799e-03],
         [ 3.3256e-03, -2.1308e-02,  7.7113e-03,  4.91

TRAINING:  31%|███▏      | 5/16 [00:01<00:04,  2.61it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-0.0064, -0.0042,  0.0040,  0.0461, -0.0029],
         [-0.0198, -0.0010,  0.0038,  0.0147,  0.0087],
         [-0.0153, -0.0196, -0.0072,  0.0160,  0.0218],
         ...,
         [ 0.0117, -0.0101,  0.0075, -0.0056,  0.0100],
         [ 0.0111, -0.0102,  0.0076, -0.0064,  0.0103],
         [ 0.0111, -0.0089,  0.0077, -0.0092,  0.0105]],

        [[ 0.0236, -0.0938,  0.0075,  0.0477, -0.0272],
         [ 0.0029, -0.0109, -0.0060, -0.0452, -0.0076],
         [-0.0040,  0.0166, -0.0106, -0.0763, -0.0009],
         ...,
         [-0.0056,  0.0264, -0.0056, -0.0602,  0.0076],
         [-0.0054,  0.0260, -0.0057, -0.0600,  0.0076],
         [-0.0053,  0.0251, -0.0058, -0.0585,  0.0075]],

        [[ 0.0156, -0.1091, -0.0343,  0.1037,  0.0024],
         [ 0.0078, -0.0545, -0.0171,  0.0519,  0.0012],
         [-0.0040,  0.0059, -0.0204, -0.0545,  0.0119],
         ...,
         [-0.0002,  0.0017,  0.003

TRAINING:  38%|███▊      | 6/16 [00:02<00:03,  2.64it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-2.4291e-04, -8.1128e-02,  1.8049e-02,  8.7624e-02,  1.0988e-02],
         [ 2.0759e-03, -5.2736e-02,  1.3128e-02,  5.2073e-02,  1.1345e-02],
         [-1.1904e-03, -5.2954e-02,  1.0338e-02,  4.6940e-02,  6.3780e-03],
         ...,
         [ 9.6217e-03, -3.0325e-02,  9.5299e-03,  4.2861e-02,  6.1723e-04],
         [ 9.9818e-03, -3.1535e-02,  1.0225e-02,  4.4350e-02,  6.0496e-04],
         [ 1.0412e-02, -3.1370e-02,  1.0049e-02,  4.4287e-02,  6.0068e-04]],

        [[ 1.4369e-02, -1.0176e-01, -1.5666e-04,  4.2841e-02,  4.9581e-03],
         [ 9.9026e-04, -2.1001e-02, -9.2254e-03, -3.9718e-02,  9.1926e-03],
         [-3.4578e-03,  5.9138e-03, -1.2270e-02, -6.7206e-02,  1.0526e-02],
         ...,
         [-3.0545e-03,  2.3294e-02, -7.8358e-03, -5.9225e-02,  9.2763e-03],
         [-2.6621e-03,  2.3661e-02, -7.3875e-03, -6.2153e-02,  9.5330e-03],
         [-2.7021e-03,  2.1156e-02, -7.4819e-03, -5.66

TRAINING:  44%|████▍     | 7/16 [00:02<00:03,  2.63it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-0.0158, -0.0586, -0.0107,  0.0950, -0.0282],
         [-0.0269, -0.0785, -0.0175,  0.0805, -0.0161],
         [-0.0257, -0.0751, -0.0072,  0.0675, -0.0002],
         ...,
         [ 0.0090, -0.0196,  0.0116,  0.0319,  0.0009],
         [ 0.0084, -0.0189,  0.0127,  0.0309,  0.0010],
         [ 0.0083, -0.0191,  0.0130,  0.0308,  0.0009]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.0140,  0.0086,  0.0220, -0.0113,  0.0236],
         [-0.0023, -0.0096,  0.0061,  0.0086,  0.0080],
         ...,
         [ 0.0127, -0.0147,  0.0124,  0.0129,  0.0079],
         [ 0.0130, -0.0168,  0.0127,  0.0146,  0.0079],
         [ 0.0116, -0.0130,  0.0118,  0.0111,  0.0074]],

        [[ 0.0031, -0.0846,  0.0103,  0.1185, -0.0184],
         [ 0.0004, -0.0474,  0.0159,  0.0838, -0.0210],
         [ 0.0003, -0.0316,  0.0106,  0.0559, -0.0140],
         ...,
         [-0.0060,  0.0037, -0.004

TRAINING:  50%|█████     | 8/16 [00:03<00:03,  2.64it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 1.3135e-02, -3.5612e-02,  2.7732e-04,  1.3822e-01,  9.0842e-03],
         [ 2.1161e-02, -3.1412e-02,  1.1721e-02,  1.3334e-01,  1.7877e-03],
         [ 1.4974e-02, -3.1965e-02, -6.3353e-04,  1.2460e-01, -4.3150e-03],
         ...,
         [ 1.0603e-02, -2.8301e-02,  6.3089e-03,  3.1869e-02,  2.3752e-03],
         [ 1.1500e-02, -2.9260e-02,  7.1796e-03,  3.3050e-02,  2.0266e-03],
         [ 1.0977e-02, -2.7018e-02,  6.6210e-03,  3.0896e-02,  2.0254e-03]],

        [[ 2.1381e-02,  1.5112e-02, -1.2208e-02, -2.5830e-02,  4.0948e-02],
         [-1.5167e-02, -5.7958e-03,  1.6383e-02,  2.3947e-02, -8.3352e-03],
         [-1.3576e-02, -3.4629e-03,  1.7137e-02,  6.2213e-04,  4.7111e-03],
         ...,
         [ 2.1105e-03, -1.4247e-02,  2.6925e-03,  1.7529e-02, -4.9690e-04],
         [ 1.8304e-03, -1.4491e-02,  1.9242e-03,  1.8002e-02, -9.1652e-04],
         [ 3.0430e-03, -1.5005e-02,  2.4644e-03,  1.83

TRAINING:  56%|█████▋    | 9/16 [00:03<00:02,  2.67it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 0.0372, -0.0952, -0.0578,  0.1374,  0.0164],
         [ 0.0138, -0.0191, -0.0343,  0.0071,  0.0157],
         [ 0.0058,  0.0067, -0.0263, -0.0374,  0.0154],
         ...,
         [ 0.0023, -0.0029,  0.0027,  0.0028, -0.0010],
         [ 0.0027, -0.0036,  0.0031,  0.0018, -0.0007],
         [ 0.0031, -0.0060,  0.0030,  0.0061, -0.0016]],

        [[ 0.0398, -0.0046,  0.0100,  0.1135,  0.0066],
         [ 0.0208, -0.0053, -0.0015,  0.1165, -0.0019],
         [ 0.0127, -0.0069,  0.0004,  0.0945, -0.0018],
         ...,
         [ 0.0079, -0.0151,  0.0085,  0.0139,  0.0029],
         [ 0.0081, -0.0162,  0.0082,  0.0164,  0.0022],
         [ 0.0090, -0.0185,  0.0094,  0.0202,  0.0026]],

        [[ 0.0068, -0.0347,  0.0023,  0.0922, -0.0042],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0069, -0.0321,  0.0031,  0.0714, -0.0055],
         ...,
         [ 0.0061, -0.0190,  0.006

TRAINING:  62%|██████▎   | 10/16 [00:03<00:02,  2.37it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-1.5863e-02, -1.1136e-02, -2.1538e-02,  6.9913e-02, -1.2380e-02],
         [-1.1714e-02, -2.3259e-02, -1.0888e-02,  7.5963e-02, -5.5994e-03],
         [-7.8092e-03, -1.5507e-02, -7.2587e-03,  5.0642e-02, -3.7329e-03],
         ...,
         [ 6.1233e-04, -1.1069e-02,  5.3679e-03,  2.6524e-02, -1.9989e-03],
         [ 6.1838e-04, -1.2130e-02,  5.2513e-03,  2.6517e-02, -1.4885e-03],
         [ 9.1521e-04, -1.2670e-02,  5.7992e-03,  2.8186e-02, -1.9532e-03]],

        [[-4.2403e-02,  1.2366e-02, -4.5873e-02,  1.4308e-03,  1.3742e-02],
         [-4.0079e-02, -2.2151e-04, -4.1387e-02,  4.6413e-03,  2.1798e-02],
         [-5.4500e-02,  3.9025e-03, -3.6981e-02, -1.3666e-02,  1.7142e-02],
         ...,
         [-4.9050e-03,  1.5547e-02, -3.7435e-03, -2.5694e-02,  1.8699e-03],
         [-4.5103e-03,  1.4157e-02, -4.0355e-03, -2.4181e-02,  2.2453e-03],
         [-5.2232e-03,  1.3749e-02, -3.5613e-03, -2.44

TRAINING:  69%|██████▉   | 11/16 [00:04<00:02,  2.17it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-1.8254e-02,  3.0182e-02, -7.5561e-04, -5.0348e-02,  9.4938e-03],
         [-1.8606e-02,  2.7119e-02, -6.6748e-03, -6.1605e-02,  9.8936e-03],
         [-1.8299e-02,  2.0504e-02, -5.8357e-03, -6.1765e-02,  9.6763e-03],
         ...,
         [ 3.1300e-03, -5.3075e-03,  1.3107e-03,  8.5357e-03,  1.9981e-04],
         [ 3.2020e-03, -6.7405e-03,  1.4635e-03,  1.0835e-02, -6.5091e-04],
         [ 3.3369e-03, -5.6892e-03,  1.8412e-03,  8.4081e-03, -2.7370e-04]],

        [[-2.5378e-02, -7.5437e-02, -6.5417e-03,  7.6964e-02, -2.6335e-02],
         [-2.0115e-02, -1.4938e-02, -1.0909e-02,  1.5132e-03, -1.1247e-02],
         [-1.8596e-02,  5.4807e-03, -1.2527e-02, -2.4952e-02, -6.2468e-03],
         ...,
         [-1.3443e-03, -7.3199e-03,  2.2546e-03,  1.7975e-02, -3.6531e-03],
         [-4.8311e-04, -7.7434e-03,  2.6455e-03,  1.9310e-02, -3.5853e-03],
         [-6.9618e-04, -8.8169e-03,  2.4778e-03,  2.03

TRAINING:  75%|███████▌  | 12/16 [00:05<00:01,  2.04it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 1.6524e-02, -6.4789e-02, -1.1348e-02,  1.6010e-01, -2.5300e-03],
         [ 2.0516e-02, -2.5878e-02, -9.5824e-03,  1.2865e-01,  1.0367e-02],
         [ 2.2106e-02, -1.0216e-02, -6.6423e-04,  1.2472e-01,  5.8728e-03],
         ...,
         [ 4.7540e-03, -1.8803e-02,  7.6693e-03,  4.0700e-02, -2.9700e-03],
         [ 4.9792e-03, -1.9945e-02,  7.6349e-03,  4.1356e-02, -3.0150e-03],
         [ 4.8868e-03, -1.9456e-02,  7.0619e-03,  3.8327e-02, -2.6369e-03]],

        [[ 2.3893e-02, -5.9445e-02,  1.8343e-02,  9.2589e-02,  1.3809e-02],
         [ 1.9250e-02, -5.4089e-02,  1.9825e-02,  8.1833e-02,  2.0275e-02],
         [ 1.7359e-02, -5.9666e-02,  1.8904e-02,  7.6725e-02,  1.8744e-02],
         ...,
         [ 1.7386e-02, -1.5695e-02,  1.6578e-02, -1.2615e-02,  1.7317e-02],
         [ 1.7493e-02, -1.5622e-02,  1.6872e-02, -1.1693e-02,  1.6640e-02],
         [ 1.8300e-02, -1.6982e-02,  1.7628e-02, -1.19

TRAINING:  81%|████████▏ | 13/16 [00:05<00:01,  1.96it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 4.5263e-03, -8.8515e-02, -6.8705e-03,  9.6899e-02, -2.3984e-02],
         [-3.9040e-03, -3.3798e-02, -9.2723e-03,  3.1661e-02, -1.2580e-02],
         [-6.3170e-03, -1.3484e-02, -9.8257e-03,  6.6886e-03, -8.9319e-03],
         ...,
         [-1.1836e-02,  3.6336e-02, -6.2406e-03, -6.6462e-02,  4.3713e-03],
         [-1.1666e-02,  3.8108e-02, -6.3962e-03, -6.8908e-02,  4.4019e-03],
         [-1.1553e-02,  3.6961e-02, -5.6685e-03, -6.6402e-02,  4.1297e-03]],

        [[ 1.8450e-02, -7.2681e-02,  1.6922e-02,  1.5658e-01, -1.4811e-02],
         [ 9.2249e-03, -3.6341e-02,  8.4612e-03,  7.8290e-02, -7.4054e-03],
         [ 1.6221e-02, -5.9030e-02,  3.2829e-03,  1.2423e-01, -1.6948e-03],
         ...,
         [ 8.1625e-03, -2.5523e-02,  1.0080e-02,  4.1627e-02, -1.2379e-03],
         [ 7.7677e-03, -2.6830e-02,  9.4298e-03,  4.2127e-02, -1.4458e-03],
         [ 8.3869e-03, -2.6926e-02,  1.0006e-02,  4.29

TRAINING:  88%|████████▊ | 14/16 [00:06<00:01,  1.89it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[-1.8807e-02, -2.6695e-02, -1.6992e-02,  3.8480e-02,  2.2093e-02],
         [-2.8951e-02, -1.2538e-02, -1.4038e-02,  2.3346e-02,  1.5073e-02],
         [-6.2688e-03, -8.8979e-03, -5.6640e-03,  1.2826e-02,  7.3642e-03],
         ...,
         [-3.5152e-03,  1.4472e-03, -1.0807e-04, -5.6911e-04, -2.2550e-03],
         [-2.9980e-03,  3.3781e-05,  5.3492e-05,  6.8999e-04, -2.5201e-03],
         [-3.0842e-03,  9.9535e-04, -5.0533e-04, -3.8882e-04, -2.3892e-03]],

        [[-2.8493e-02, -1.3684e-02, -4.7527e-03, -2.3264e-02, -3.2977e-03],
         [-2.5581e-02, -1.6601e-02, -6.6834e-03, -1.6265e-02, -2.0957e-03],
         [-2.5583e-02, -1.9414e-02, -6.7889e-03, -1.1954e-02, -3.0665e-03],
         ...,
         [ 2.2811e-03, -7.6374e-03,  4.3547e-04,  1.0362e-02,  1.2765e-03],
         [ 1.7366e-03, -5.6816e-03,  7.0590e-04,  7.3784e-03,  1.2692e-03],
         [ 2.1528e-03, -8.2603e-03,  1.1387e-03,  1.05

TRAINING:  94%|█████████▍| 15/16 [00:06<00:00,  2.04it/s]

Head attention input torch.Size([32, 196, 768])
Output of head attention layers tensor([[[ 0.0234,  0.0466, -0.0058,  0.0765,  0.0313],
         [ 0.0206,  0.0482,  0.0025,  0.0966,  0.0157],
         [ 0.0057,  0.0304, -0.0182,  0.0845,  0.0126],
         ...,
         [ 0.0039, -0.0129,  0.0050,  0.0242, -0.0012],
         [ 0.0045, -0.0126,  0.0054,  0.0239, -0.0011],
         [ 0.0039, -0.0123,  0.0052,  0.0232, -0.0015]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0016,  0.0182, -0.0232,  0.1012, -0.0088],
         [-0.0034,  0.0191, -0.0199,  0.0623,  0.0034],
         ...,
         [-0.0011, -0.0002,  0.0012,  0.0018, -0.0026],
         [-0.0015,  0.0012,  0.0004, -0.0012, -0.0020],
         [-0.0011, -0.0013,  0.0009,  0.0042, -0.0027]],

        [[ 0.0223, -0.0251,  0.0041,  0.1287,  0.0124],
         [ 0.0123, -0.0091,  0.0079,  0.0651, -0.0043],
         [ 0.0105, -0.0244, -0.0057,  0.1276, -0.0085],
         ...,
         [-0.0053,  0.0274, -0.007

TRAINING: 100%|██████████| 16/16 [00:06<00:00,  2.34it/s]

Head attention input torch.Size([20, 196, 768])
Output of head attention layers tensor([[[ 0.0176, -0.0802,  0.0166,  0.1606, -0.0167],
         [ 0.0130, -0.0769,  0.0190,  0.1480, -0.0148],
         [ 0.0133, -0.0655,  0.0137,  0.1309, -0.0107],
         ...,
         [ 0.0047, -0.0261,  0.0117,  0.0562, -0.0067],
         [ 0.0050, -0.0268,  0.0118,  0.0569, -0.0071],
         [ 0.0055, -0.0255,  0.0112,  0.0544, -0.0075]],

        [[ 0.0047, -0.0813, -0.0109,  0.1382, -0.0327],
         [ 0.0024, -0.0407, -0.0054,  0.0691, -0.0164],
         [ 0.0059, -0.0492,  0.0015,  0.0879, -0.0174],
         ...,
         [ 0.0020, -0.0190,  0.0039,  0.0428, -0.0070],
         [ 0.0024, -0.0192,  0.0036,  0.0409, -0.0062],
         [ 0.0025, -0.0190,  0.0041,  0.0433, -0.0069]],

        [[-0.0290, -0.0219, -0.0141,  0.0366,  0.0089],
         [-0.0213, -0.0199, -0.0211,  0.0395,  0.0137],
         [-0.0274, -0.0150, -0.0156,  0.0326,  0.0099],
         ...,
         [-0.0007, -0.0045,  0.001


