In [1]:
import os
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn.functional as F
# torch.multiprocessing.set_start_method('spawn')
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch import optim

import clip
# from transformers import CLIPProcessor, CLIPModel

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
CONTEXT_LENGTH = 77
SAMPLE_SIZE = None
TRAIN_PROP = 0.7
VAL_PROP = 0.2
BATCH_SIZE = 4
NUM_WORKERS = 0
EPOCHS = 4
LEARNING_RATE = 5e-5

# Params same as paper
BETAS = (0.9,0.98)
EPS=1e-6
WEIGHT_DECAY=0.2

In [3]:
models = clip.available_models()
print(models)
model, preprocess = clip.load('RN50', device, jit=False)

['RN50', 'ViT-B/32']


## Create Dataset

In [4]:
class TextImg_Dataset(Dataset):
    def __init__(self, data_dir, images_dir, images, titles):
        self.data_dir = data_dir
        self.images_dir = images_dir
        self.images_paths = images
        self.titles = titles
        assert (len(titles)==len(images))

    def __len__(self):
        return len(self.images_paths)
    
    def __getitem__(self, index):
        txt = self.titles[index]
        text_tokens = clip.tokenize(txt, context_length=CONTEXT_LENGTH).squeeze().to(device) # torch.Size([77])
        # print(type(text_tokens), text_tokens.shape)
        img = self.images_paths[index]
        path = os.path.join(self.data_dir, self.images_dir, img)
        image = preprocess(Image.open(path)).to(device) # torch.Size([3, 224, 224])
        # print(type(image), image.shape)
        return text_tokens, image


In [5]:
summary_bert_path = "ComputerVision_Data/Summaries/Summary_Bert.csv"
data_dir = "ComputerVision_Data"
images_dir = "Food Images/Food Images"

In [6]:
summary_df = pd.read_csv(summary_bert_path)
liste_images = summary_df["Image_Name"].tolist()
liste_textes = summary_df["Summary"].tolist()
if SAMPLE_SIZE != None:
    liste_images=liste_images[:SAMPLE_SIZE]
    liste_textes=liste_textes[:SAMPLE_SIZE]
liste_images = [image + ".jpg" for image in liste_images]

In [7]:
liste_images[0], liste_textes[0]

('miso-butter-roast-chicken-acorn-squash-panzanella.jpg',
 'Roast chicken in a large castiron skillet and roast on middle rack until an instant')

In [8]:
assert(len(liste_textes) == len(liste_images))
length = len(liste_images)

train_size = int(length*TRAIN_PROP)
val_size = int(length*VAL_PROP)

In [9]:
train_dataset = TextImg_Dataset(data_dir, images_dir, liste_images[:train_size], liste_textes[:train_size])
val_dataset = TextImg_Dataset(data_dir, images_dir, liste_images[train_size:train_size+val_size], liste_textes[train_size:train_size+val_size])
test_dataset = TextImg_Dataset(data_dir, images_dir, liste_images[train_size+val_size:], liste_textes[train_size+val_size:])

In [10]:
train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_dl = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

In [11]:
len(train_dl), next(iter(train_dl))[0]

(362,
 tensor([[49406,   589,  2972,   753,   655,  1105,   593,  4816,  1504, 17379,
            269,   518,  1033, 17379,   328,  1311,   783,  1483, 49407,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [49406,   530,   320, 25066, 12919,   615,   518, 11788,  8541,   518,
           5574,   537, 10780,  1123,   269,  3536,  5574,   541, 49407,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,   

## Fine Tune CLIP

In [12]:
def compute_loss(text_embeddings, image_embeddings):
    logits = (text_embeddings @ image_embeddings.T)
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax(
        (images_similarity + texts_similarity) / 2, dim=-1
    )
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
    return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = torch.nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    
def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

In [13]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE,betas=BETAS,eps=EPS,weight_decay=WEIGHT_DECAY)
train_losses = []
val_losses = []

clip.model.convert_weights(model)

for ep in range(3):
    model.train()
    train_loss = 0
    for batch in tqdm(train_dl): 
        optimizer.zero_grad()
        text, img = batch
        # print(text.shape, img.shape)
        text_features = model.encode_text(text)
        image_features = model.encode_image(img)
        loss = compute_loss(text_features, image_features)
        train_loss+=loss.item()
        loss.backward()
        optimizer.step()
    train_losses.append(train_loss)

    # VALIDATION LOOP
    model.eval()
    val_loss = 0
    for batch in tqdm(val_dl): 
        text, img = batch
        # print(text.shape, img.shape)
        text_features = model.encode_text(text)
        image_features = model.encode_image(img)
        loss = compute_loss(text_features, image_features)
        val_loss+=loss.item()
    val_losses.append(val_loss)
        
        

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 362/362 [00:26<00:00, 13.41it/s]
100%|██████████| 104/104 [00:02<00:00, 40.16it/s]
100%|██████████| 362/362 [00:25<00:00, 14.13it/s]
100%|██████████| 104/104 [00:02<00:00, 41.28it/s]
100%|██████████| 362/362 [00:25<00:00, 14.08it/s]
100%|██████████| 104/104 [00:02<00:00, 41.33it/s]


In [14]:
train_losses, val_losses

([305.259033203125, 257.8232421875, 255.2978515625],
 [78.88720703125, 80.77978515625, nan])

In [15]:
model.eval()
val_loss = 0
for batch in tqdm(val_dl): 
    text, img = batch
    # print(text.shape, img.shape)
    text_features = model.encode_text(text)
    image_features = model.encode_image(img)
    print(text_features, image_features)
    loss = compute_loss(text_features, image_features)
    print(loss)
    val_loss+=loss.item()
val_losses.append(val_loss)

  7%|▋         | 7/104 [00:00<00:03, 31.47it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0786],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0080, -0.0096, -0.0532,  ...,  0.0170,  0.0515,  0.0333],
        [ 0.3955, -0.2396, -2.9629,  ...,  1.2734,  3.5762,  1.4014],
        [ 0.0078, -0.0091, -0.0533,  ...,  0.0178,  0.0508,  0.0313],
        [ 0.0060, -0.0075, -0.0431,  ...,  0.0128,  0.0414,  0.0286]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(0.6934, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0785],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 

 14%|█▍        | 15/104 [00:00<00:02, 35.19it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0786]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0101, -0.0105, -0.0628,  ...,  0.0211,  0.0608,  0.0385],
        [ 0.0349, -0.0177, -0.2295,  ...,  0.0948,  0.2338,  0.1113],
        [ 0.0380, -0.0238, -0.4919,  ...,  0.2415,  0.5161,  0.1616],
        [ 0.2350, -0.1681, -1.9238,  ...,  0.9160,  2.2891,  0.8350]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(0.6973, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 

 22%|██▏       | 23/104 [00:00<00:02, 36.93it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0087, -0.0091, -0.0572,  ...,  0.0183,  0.0554,  0.0343],
        [ 0.4897, -0.2649, -3.6211,  ...,  1.6162,  4.4414,  1.6709],
        [ 0.5439, -0.3098, -4.1211,  ...,  1.7637,  5.0117,  1.9580],
        [ 0.0077, -0.0095, -0.0482,  ...,  0.0147,  0.0469,  0.0305]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(0.6934, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3306,  0.0785],
        [ 

 30%|██▉       | 31/104 [00:00<00:01, 36.99it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1632, -0.0756, -0.3525,  ...,  0.2400,  0.3306,  0.0786],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3306,  0.0786]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0059, -0.0083, -0.0471,  ...,  0.0136,  0.0445,  0.0294],
        [ 0.0047, -0.0067, -0.0396,  ...,  0.0115,  0.0384,  0.0243],
        [ 0.0074, -0.0085, -0.0463,  ...,  0.0140,  0.0447,  0.0302],
        [ 0.0217, -0.0152, -0.1659,  ...,  0.0711,  0.1708,  0.0745]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(0.9009, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3306,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 

 38%|███▊      | 39/104 [00:01<00:01, 37.06it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0046, -0.0065, -0.0376,  ...,  0.0103,  0.0356,  0.0229],
        [ 0.0055, -0.0063, -0.0422,  ...,  0.0127,  0.0401,  0.0257],
        [ 0.0170, -0.0114, -0.0992,  ...,  0.0367,  0.0985,  0.0570],
        [ 0.0053, -0.0068, -0.0378,  ...,  0.0108,  0.0357,  0.0248]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(1.7676, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 

 45%|████▌     | 47/104 [00:01<00:01, 36.99it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0786]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 4.6289e-01, -2.7881e-01, -3.4746e+00,  ...,  1.5879e+00,
          4.2578e+00,  1.6084e+00],
        [ 5.2223e-03, -7.2174e-03, -3.9520e-02,  ...,  1.1642e-02,
          3.6469e-02,  2.5742e-02],
        [ 1.1572e+00, -7.0215e-01, -8.6406e+00,  ...,  3.6191e+00,
          1.0625e+01,  4.3984e+00],
        [ 5.2832e-01, -3.2837e-01, -3.9004e+00,  ...,  1.6650e+00,
          4.7578e+00,  1.8574e+00]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)
tensor(0.7471, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1632, 

 53%|█████▎    | 55/104 [00:01<00:01, 36.82it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2401,  0.3303,  0.0787]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.3853, -0.2175, -3.0098,  ...,  1.3672,  3.6562,  1.3867],
        [ 0.3782, -0.2253, -2.7266,  ...,  1.2295,  3.3262,  1.2568],
        [ 0.0055, -0.0087, -0.0401,  ...,  0.0117,  0.0392,  0.0259],
        [ 0.0286, -0.0203, -0.2410,  ...,  0.1083,  0.2578,  0.1029]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(0.6934, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 

 61%|██████    | 63/104 [00:01<00:01, 36.99it/s]

tensor([[ 0.1632, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3525,  ...,  0.2400,  0.3306,  0.0786]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 6.2158e-01, -3.1226e-01, -4.4961e+00,  ...,  1.8955e+00,
          5.4805e+00,  2.1797e+00],
        [ 9.4873e-01, -4.8682e-01, -6.8398e+00,  ...,  2.8379e+00,
          8.3750e+00,  3.4355e+00],
        [ 3.4790e-01, -2.1326e-01, -2.5547e+00,  ...,  1.1963e+00,
          3.1094e+00,  1.1768e+00],
        [ 4.3526e-03, -6.8207e-03, -3.7628e-02,  ...,  1.1024e-02,
          3.5126e-02,  2.3331e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)
tensor(0.6934, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1632, 

 68%|██████▊   | 71/104 [00:01<00:00, 37.30it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 5.9624e-03, -6.5346e-03, -4.0436e-02,  ...,  1.1620e-02,
          3.8574e-02,  2.5284e-02],
        [ 1.0029e+00, -6.0596e-01, -7.5547e+00,  ...,  3.1660e+00,
          9.3359e+00,  3.7441e+00],
        [ 6.6147e-03, -8.4000e-03, -4.7546e-02,  ...,  1.5320e-02,
          4.7150e-02,  2.9205e-02],
        [ 7.2823e-03, -9.6436e-03, -5.4138e-02,  ...,  1.7899e-02,
          5.2734e-02,  3.4058e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)
tensor(0.6934, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0786],
        [ 0.1631, 

 76%|███████▌  | 79/104 [00:02<00:00, 36.91it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1632, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2401,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0043, -0.0076, -0.0379,  ...,  0.0114,  0.0358,  0.0242],
        [ 0.0055, -0.0089, -0.0457,  ...,  0.0135,  0.0432,  0.0281],
        [ 0.0050, -0.0069, -0.0377,  ...,  0.0106,  0.0356,  0.0230],
        [ 0.0052, -0.0082, -0.0406,  ...,  0.0117,  0.0380,  0.0253]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(1.4385, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0756, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 

 84%|████████▎ | 87/104 [00:02<00:00, 37.01it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 6.4795e-01, -3.7305e-01, -4.8438e+00,  ...,  2.0332e+00,
          5.9727e+00,  2.3828e+00],
        [ 6.3972e-03, -8.1329e-03, -4.6417e-02,  ...,  1.4389e-02,
          4.4128e-02,  2.8992e-02],
        [ 4.7379e-03, -6.6223e-03, -3.8300e-02,  ...,  1.1215e-02,
          3.5919e-02,  2.4323e-02],
        [ 7.4081e-03, -8.7585e-03, -4.8370e-02,  ...,  1.5198e-02,
          4.8248e-02,  3.0685e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)
tensor(0.7085, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, 

 91%|█████████▏| 95/104 [00:02<00:00, 36.84it/s]

tensor([[ 0.1632, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0756, -0.3523,  ...,  0.2400,  0.3303,  0.0787]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 7.9575e-03, -8.3847e-03, -5.2887e-02,  ...,  1.6342e-02,
          5.1575e-02,  3.2776e-02],
        [ 1.1504e+00, -6.3867e-01, -8.6328e+00,  ...,  3.6055e+00,
          1.0508e+01,  4.3750e+00],
        [ 3.9520e-03, -6.8283e-03, -3.8391e-02,  ...,  1.1200e-02,
          3.5278e-02,  2.3468e-02],
        [ 5.2910e-03, -7.5150e-03, -4.1199e-02,  ...,  1.1734e-02,
          3.9154e-02,  2.4567e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)
tensor(0.7319, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 0.1631, 

100%|██████████| 104/104 [00:02<00:00, 36.65it/s]

tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3306,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0785]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>) tensor([[ 0.0063, -0.0083, -0.0441,  ...,  0.0129,  0.0425,  0.0276],
        [ 0.6826, -0.3860, -5.0625,  ...,  2.1270,  6.2539,  2.4824],
        [ 0.0069, -0.0090, -0.0461,  ...,  0.0137,  0.0457,  0.0288],
        [ 0.4517, -0.2664, -3.3145,  ...,  1.4492,  4.0391,  1.5586]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
tensor(0.6934, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
tensor([[ 0.1631, -0.0755, -0.3523,  ...,  0.2399,  0.3303,  0.0785],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0786],
        [ 0.1631, -0.0755, -0.3523,  ...,  0.2400,  0.3303,  0.0787],
        [ 


