In [5]:
from tqdm.notebook import tqdm
import clip
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

clip_model, clip_preprocess = clip.load("ViT-B/32", device)

cuda


In [6]:
from datasets import *
dataset_obj = Flowers102(0, 50)
train_loader, _ = dataset_obj.get_train_loaders(transform_fn=clip_preprocess)
test_loader = dataset_obj.get_test_loader(transform_fn=clip_preprocess)
classes = dataset_obj.classes

In [7]:
import copy

In [8]:
def get_clip_features(dataset):
    all_features = []
    all_labels = []

    global clip_model

    with torch.no_grad():
        for images, labels in tqdm(dataset):
            features = clip_model.encode_image(images.to(device))
            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features), torch.cat(all_labels)

train_features, train_labels = get_clip_features(train_loader)
test_features, test_labels = get_clip_features(test_loader)

  0%|          | 0/132 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

In [9]:
def batch(iterable1,iterable2, n=1):
    l = len(iterable1)
    for ndx in range(0, l, n):
        yield (iterable1[ndx:min(ndx + n, l)], iterable2[ndx:min(ndx + n, l)])

out_train = batch(train_features,train_labels, 50)
out_test = batch(test_features,test_labels, 50)

In [10]:
import torch.nn as nn
import torch.optim as optim

class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        outputs = self.linear(x)
        return outputs
    
model = LogisticRegression(512,len(classes))
model.logit_scale = nn.Parameter(torch.ones([], device=device))
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-5)
num_epochs = 500
# scheduler = optim.lr_scheduler.MultiStepLR(
#             optimizer, milestones=[300, 500, 700, 900], gamma=0.1
#         )

In [11]:
def num_correct_preds(outputs, labels):
    _, predicted = outputs.max(1)
    return predicted.eq(labels).sum().item()

In [12]:
def cosine_loss(output, target):
    loss = 1 - torch.cosine_similarity(output, target)
    return loss

def calc_loss(outputs, labels, loss_name="ce"):
    
    loss_labels = labels
    
    criterion = nn.CrossEntropyLoss()

    if loss_name == "ce":
        return criterion(outputs, loss_labels) 

    elif loss_name == "dot":
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)
        return -(outputs * loss_labels).sum(-1).mean()

    elif loss_name == "cosine":
        print(outputs.shape, loss_labels.shape)
        loss = torch.mean(cosine_loss(outputs, loss_labels))
        return loss

    elif loss_name == "temperature_ce":
        image_features = outputs

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        return criterion(logits_per_image, labels)

In [13]:
# import torch
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter()

In [15]:
import random

model.to(device)
model = model.to(torch.float16)

best_model = None
best_acc = -1

for epoch in tqdm(range(num_epochs+1)):
    
    model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0

    if epoch%50 == 0:
        print("Shuffling")
        c = list(zip(train_features, train_labels))
        random.shuffle(c)
        train_features, train_labels = zip(*c)
        train_features = torch.stack(list(train_features))
        train_labels = torch.stack(list(train_labels))
        
    out_train = batch(train_features,train_labels, 50)
    for inputs, labels in out_train:
        
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()

        outputs = model(inputs)
        print(outputs, labels)
        loss = calc_loss(outputs, labels)

        loss.backward()

        optimizer.step()
        running_loss+=loss
        total += len(labels)
        correct += num_correct_preds(outputs, labels)
        
    if epoch%10 == 0:
        model.eval()
        out_test = batch(test_features,test_labels, 50)
        
        test_running_loss = 0.0
        test_correct = 0.0
        test_total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(out_test):
        #         inputs = inputs / inputs.norm(dim=-1, keepdim=True)
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_running_loss += loss.item()
                test_total += labels.size(0)
                test_correct += num_correct_preds(outputs, labels)

        epoch_loss = test_running_loss / len(test_loader)
#         writer.add_scalar("Loss/val", epoch_loss, epoch)
        epoch_accuracy = test_correct * 100 / test_total
#         writer.add_scalar("Accuracy/val", epoch_loss, epoch)
        if epoch_accuracy > best_acc:
            best_model= copy.deepcopy(model)
            best_acc = epoch_accuracy
            print("Found best model")
#         print(
#             f"Testing: Epoch {epoch} || Loss: {epoch_loss:7.3f} || Accuracy: {epoch_accuracy:6.2f}%"
#         )

        
    epoch_loss = running_loss/len(train_loader)
#     writer.add_scalar("Loss/train", epoch_loss, epoch)
    epoch_accuracy = correct*100/total
#     writer.add_scalar("Accuracy/train", epoch_loss, epoch)
    print(
        f"Training: Epoch {epoch} || Loss: {epoch_loss:7.3f} || Accuracy: {epoch_accuracy:6.2f}%"
    )

#     writer.flush()
# writer.close()


  0%|          | 0/501 [00:00<?, ?it/s]

Shuffling
tensor([[-0.3679, -0.0978, -0.5479,  ..., -0.2454,  0.3862, -0.0397],
        [-0.4412, -0.2700, -0.3984,  ..., -0.2888,  0.0526, -0.1776],
        [-0.4080, -0.4868, -0.3340,  ..., -0.2244, -0.1191,  0.0240],
        ...,
        [-0.4497, -0.2360, -0.6748,  ..., -0.1409, -0.0643, -0.1108],
        [-0.3154, -0.4355, -0.4141,  ..., -0.0940, -0.1173, -0.1974],
        [-0.4275, -0.2484, -0.2269,  ..., -0.1661,  0.1283, -0.2239]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.4355, -0.1898, -0.4900,  ..., -0.1119,  0.0995,  0.1312],
        [-0.4509, -0.5264, -0.2952,  ...,  0.0543, -0.1660,  0.0556],
        [-0.7319, -0.4072, -0.5239,  ..., -0.3

tensor([[-0.4373, -0.3540, -0.5483,  ..., -0.1451, -0.2379, -0.2269],
        [-0.4099, -0.2698, -0.2001,  ..., -0.3721, -0.2338, -0.3911],
        [-0.4871, -0.4941, -0.3914,  ..., -0.1663, -0.0020, -0.0408],
        ...,
        [-0.3879, -0.5239, -0.3579,  ..., -0.1328, -0.2047, -0.1450],
        [-0.5493, -0.3093, -0.4033,  ..., -0.2477,  0.1893,  0.0130],
        [-0.5942, -0.6270, -0.4426,  ..., -0.2157,  0.1630,  0.0097]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 93,   9,  86,  32,  54,  77,  94,  84,  64,  73,  23,  96,  18,  55,
         11,  72,  32,  87,  87,  15,  83,  40,  72,   5,  37,  33,  72,  61,
         40,  77,   5,  41,  83,  79,   5, 100,  49,  91,  26,  77,  57,  49,
         44,  56,  21,  19,  74,  51,   8,  82], device='cuda:0')
tensor([[-0.6655, -0.6221, -0.5337,  ..., -0.1152,  0.0575, -0.4243],
        [-0.4248, -0.1993, -0.5283,  ..., -0.2274, -0.1708, -0.0638],
        [-0.6372, -0.2812, -0.0993,  ..., -0.1088,  0.03

tensor([[-0.5449, -0.1512, -0.3894,  ..., -0.1322,  0.0071,  0.2236],
        [-0.7153, -0.2993, -0.0610,  ..., -0.1519,  0.2288, -0.1819],
        [-0.5986, -0.6348, -0.4434,  ..., -0.0690,  0.5205, -0.1725],
        ...,
        [-0.6797, -0.1755, -0.3716,  ..., -0.3562,  0.1998, -0.0981],
        [-0.6802, -0.4692, -0.5190,  ..., -0.2000, -0.2249, -0.1918],
        [-0.8755, -0.4890, -0.5347,  ..., -0.0594,  0.1694, -0.3596]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 77,  48,  95,  80,  96,  42,  36,  85,  60,  49,  26,  80,  45,  68,
         57,  85,  75,  22,  42,  27,  73,  87,  56,  84,  96,  77,  43,  57,
         85,  51,  32, 101,   5,  11,  56,  96,  92,  36,  75,  64,  74,  35,
         17,  89,  33,  82,  11,  62,  68,  55], device='cuda:0')
tensor([[-0.5327, -0.5151, -0.5254,  ..., -0.1924,  0.1123, -0.1154],
        [-0.6406, -0.7085, -0.5225,  ..., -0.1633,  0.3484, -0.1145],
        [-0.5615, -0.5107, -0.4443,  ..., -0.1648, -0.16

0it [00:00, ?it/s]

Found best model
Training: Epoch 0 || Loss:   4.250 || Accuracy:  12.70%
tensor([[-0.5391, -0.1203, -0.5479,  ..., -0.2428,  0.3557, -0.0558],
        [-0.6079, -0.4202, -0.4768,  ..., -0.2551,  0.0973, -0.2026],
        [-0.5332, -0.5771, -0.3918,  ..., -0.1941, -0.1130, -0.0932],
        ...,
        [-0.5718, -0.2705, -0.7163,  ..., -0.1710, -0.0240, -0.1954],
        [-0.4609, -0.5708, -0.4790,  ..., -0.0855, -0.0532, -0.2683],
        [-0.5972, -0.3799, -0.2969,  ..., -0.1443,  0.1581, -0.2581]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.5962, -0.2529, -0.5244,  ..., -0.1220,  0.1285,  0.1156],
        [-0.5669, -0.6118, -0.3391,  ...,  0.0746, -0

tensor([[-0.5376, -0.2605, -0.4475,  ..., -0.1575,  0.1360, -0.1863],
        [-0.6226, -0.5396, -0.4641,  ..., -0.0277,  0.2422, -0.1029],
        [-0.6323, -0.7251, -0.5991,  ..., -0.0152,  0.1406, -0.1688],
        ...,
        [-0.4895, -0.4966, -0.3398,  ...,  0.0774, -0.0634, -0.2035],
        [-0.4941, -0.3960, -0.5024,  ..., -0.1071,  0.0365,  0.0447],
        [-0.6265, -0.2078, -0.4495,  ...,  0.0328,  0.1482, -0.0807]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([79, 60, 84, 51, 43, 92, 80, 13, 94, 96, 54, 73, 89, 54, 48, 38, 66, 80,
        77, 67, 94, 60, 36, 15, 49, 49, 10, 51, 90, 52, 85, 45, 84, 79, 83, 58,
        59, 39, 40, 49, 58, 70, 24, 23, 51, 97, 47, 14, 77, 77],
       device='cuda:0')
tensor([[-0.5854, -0.4753, -0.5703,  ..., -0.1180, -0.0998, -0.0752],
        [-0.3926, -0.6377, -0.4661,  ..., -0.0809,  0.0408,  0.0536],
        [-0.4922, -0.5405, -0.5674,  ..., -0.1464,  0.0805, -0.2255],
        ...,
        [-0.5557, -0.469

tensor([[-0.7651, -0.4844, -0.6655,  ...,  0.0280,  0.1321, -0.3125],
        [-0.7480, -0.4788, -0.4451,  ..., -0.2595,  0.0989, -0.3821],
        [-0.7544, -0.5288, -0.5425,  ..., -0.1267,  0.1824, -0.1538],
        ...,
        [-0.6753, -0.4353, -0.5479,  ..., -0.0043,  0.0854,  0.0177],
        [-0.7407, -0.7935, -0.8242,  ..., -0.0832,  0.1710, -0.2632],
        [-0.5918, -0.8447, -0.7700,  ..., -0.1968,  0.2778, -0.1621]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  43,  49,  17,  22,  21, 101,  78,  85,  74,  38,  98,  54,  77,
         59,  15,  41,  31,  78,  20,  73,  38,  11,  92,  49,  40,  73,  19,
         49,  79,  85, 101,  47,  88,  37,  90,  53,  41,  66,  90,  89,  92,
         49,  99,  90,  88,  25,  96,  15,  78], device='cuda:0')
tensor([[-0.7939, -0.2291, -0.4329,  ...,  0.0665,  0.2715, -0.0969],
        [-0.6294, -0.7041, -0.4678,  ..., -0.0338, -0.0942, -0.1531],
        [-0.6841, -0.4065, -0.4880,  ..., -0.0323,  0.15

tensor([[-0.7476, -0.2751, -0.1510,  ..., -0.0081, -0.1005, -0.2625],
        [-0.7261, -0.3647, -0.5835,  ..., -0.1879,  0.0648, -0.2020],
        [-0.7207, -0.7358, -0.5093,  ..., -0.0682,  0.2372, -0.2313],
        ...,
        [-0.6519, -0.4954, -0.4976,  ..., -0.2422,  0.3379, -0.3813],
        [-0.6797, -0.6704, -0.4863,  ..., -0.2003, -0.0328, -0.2861],
        [-0.8481, -0.5420, -0.5747,  ..., -0.3848,  0.2168, -0.2158]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([  6,  90,  84,  31,   1,  73,  82,  65,   6,  88,  31,  97,  81,  50,
         78,  84,  56,  74,  90,  23,  83,  89,  27,  83,  82,  24,  12,  84,
         56,  10,  23, 100,  96,   9,  70,  65,  72,  19,  40,  69,  49,  49,
         67,  59,  82,  30,  65, 100,  50,  28], device='cuda:0')
tensor([[-0.6475, -0.3704, -0.3887,  ..., -0.2487,  0.0259, -0.0497],
        [-0.6909, -0.1506, -0.1462,  ...,  0.0199,  0.1088, -0.0495],
        [-0.4011, -0.5645, -0.6460,  ..., -0.1995, -0.14

tensor([[-0.6724, -0.4868, -0.5464,  ..., -0.0281, -0.1172, -0.1340],
        [-0.7344, -0.7500, -0.4722,  ..., -0.1449,  0.4104, -0.2651],
        [-0.6772, -0.4656, -0.3442,  ..., -0.0538,  0.0945, -0.0812],
        ...,
        [-0.7900, -0.4797, -0.4775,  ..., -0.0920,  0.1975, -0.2673],
        [-0.6973, -0.7793, -0.7988,  ..., -0.1655,  0.6064,  0.0745],
        [-0.6118, -0.5098, -0.4338,  ..., -0.1606, -0.0173, -0.0632]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 99,  53,  69,  81,  61,  18,  78,  33,  39,  43,  65,  50,  90,  86,
         55,  15,  83,  32,  14,  77,  33,  82,  62,  24,  25,  84,  51,   2,
         56,  73,  89, 101,  14,  97,   9,  26,  49,  48,  85,  97,  49,  89,
         38,  77,  43,  64,  24,  72,  19,  72], device='cuda:0')
tensor([[-1.1465e+00, -6.1084e-01, -3.5669e-01,  ..., -2.2583e-01,
          2.6001e-01, -3.1519e-01],
        [-8.1201e-01, -7.1045e-01, -3.5547e-01,  ..., -2.6709e-01,
          5.9387e-02, -1.7

tensor([[-7.5732e-01, -7.0996e-01, -6.3379e-01,  ..., -3.2227e-01,
          9.0576e-02, -1.3025e-01],
        [-6.8555e-01, -2.7490e-01, -2.2913e-01,  ..., -2.9150e-01,
          1.7468e-01,  1.9073e-02],
        [-6.7090e-01, -3.5913e-01, -4.8315e-01,  ...,  3.8818e-02,
          4.7821e-02, -1.6833e-01],
        ...,
        [-5.5615e-01, -5.3320e-01, -6.8896e-01,  ...,  3.3092e-04,
          1.3599e-01, -5.3558e-02],
        [-7.4561e-01, -3.2104e-01, -3.6401e-01,  ..., -2.7002e-01,
          1.2109e-01, -6.9458e-02],
        [-7.6709e-01, -5.7617e-01, -5.9131e-01,  ..., -2.0569e-01,
          3.7744e-01, -3.4521e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([82, 52, 23, 42, 49, 26, 58, 20, 90, 97, 18, 11, 46, 79, 35, 51, 77, 96,
        98,  9, 88, 10, 23, 49, 89, 25, 82, 43, 90, 73, 47, 69, 15, 56, 11, 18,
        84, 23, 83, 12, 13, 43, 47, 51, 63, 73, 92, 76, 62, 25],
       device='cuda:0')
tensor([[-0.7559, -0.4717, -0.5977,  ..., -0.0324

tensor([[-0.7808, -0.7500, -0.4316,  ...,  0.0591, -0.1183, -0.3083],
        [-0.7046, -0.6611, -0.6089,  ...,  0.0212,  0.0677, -0.3198],
        [-0.8203, -0.3560, -0.5610,  ..., -0.1364,  0.2356, -0.1913],
        ...,
        [-0.7583, -0.5972, -0.3850,  ..., -0.2144,  0.1064, -0.2001],
        [-0.6748, -0.6514, -0.6616,  ...,  0.1354, -0.0955, -0.1483],
        [-0.5586, -0.6997, -0.6025,  ..., -0.0371,  0.2456, -0.4041]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 51,  96,  94,  49,  20,  89,  55,  98,  78,  49,  32,  98,  64,  74,
         87,   8,  17,  77,  49,   5,  75,  24,  38,  89,  89,   2,  11,  76,
         42,  88,  64,  28,  59,  98,  89,   3,  84,  98,  87,  43,  33,  48,
         70,  40,  33, 100,  78,   4,  99,  53], device='cuda:0')
tensor([[-0.6909, -0.7456, -0.6089,  ..., -0.2252,  0.1899, -0.0672],
        [-0.7002, -0.4773, -0.3123,  ..., -0.2430,  0.0654, -0.1362],
        [-0.7017, -0.8872, -0.6255,  ..., -0.1140,  0.39

tensor([[-0.6094, -0.5483, -0.5469,  ..., -0.2520,  0.1992, -0.3477],
        [-0.8271, -0.3145, -0.5054,  ..., -0.1383,  0.2485, -0.3245],
        [-0.5293, -0.3269, -0.2922,  ..., -0.2939, -0.0267, -0.1138],
        ...,
        [-0.8677, -0.6592, -0.5161,  ..., -0.1492,  0.0593, -0.2361],
        [-0.7490, -0.9062, -0.4897,  ..., -0.2998,  0.2585, -0.2419],
        [-0.6646, -0.4695, -0.4285,  ..., -0.2352, -0.0763, -0.1504]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 13,  79,  46,  90,  84,  74,  44,   3,  19,  56,  82,  72,  66,  32,
         61,  73,  70,  83,  48,  18,  77,  53,  89,  81,   3,  56,  11,  72,
         43,  18,  13,  95,  77,  18,  43,  48,  75,  85,  52,  36,  22, 100,
         96,  93, 100,  72,  55,  98,  50,  38], device='cuda:0')
tensor([[-0.5977, -0.8208, -0.3982,  ..., -0.1454,  0.0869,  0.0056],
        [-0.8022, -0.8804, -0.5337,  ..., -0.0972,  0.2913, -0.3364],
        [-0.9248, -0.7158, -0.3025,  ...,  0.0370,  0.29

tensor([[-0.6890, -0.4866, -0.4434,  ..., -0.0799,  0.4812, -0.1495],
        [-0.7285, -0.4980, -0.3105,  ..., -0.1949,  0.2881, -0.2162],
        [-0.7441, -0.8296, -0.6167,  ..., -0.0798,  0.3701, -0.2598],
        ...,
        [-0.7734, -0.5459, -0.5391,  ...,  0.0013, -0.1088, -0.1594],
        [-0.8853, -0.7080, -0.6362,  ..., -0.2477,  0.1136, -0.2017],
        [-0.7134, -0.5898, -0.7041,  ..., -0.1680,  0.0856, -0.2159]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 32,  49,  72,  92,  37,  92,  20,  63,  74,  86,  71,   5,  41,  62,
          5,  82,   7,  29,  53, 100,  90,  49,  20,  17,  87,  78,  88,  24,
         44,  84,  56,  83,  42,  26,  47,  57,  43,  62,  18,  56,  19,   6,
         51,  63,  98,  85,   6,  37,  43,  89], device='cuda:0')
tensor([[-0.8091, -0.3618, -0.4182,  ..., -0.1201, -0.0964, -0.2974],
        [-0.6978, -0.7925, -0.8369,  ...,  0.0743,  0.2407, -0.0514],
        [-0.7334, -0.5312, -0.6396,  ..., -0.2317, -0.14

tensor([[-0.6167, -0.0828, -0.5000,  ..., -0.1892,  0.1818,  0.0675],
        [-0.8589, -0.3501, -0.4314,  ...,  0.0055,  0.1311,  0.1799],
        [-0.7759, -0.6821, -0.5200,  ..., -0.3079,  0.1086, -0.1677],
        ...,
        [-0.6499, -0.5239, -0.5415,  ..., -0.2666, -0.0020, -0.3118],
        [-0.7539, -0.6567, -0.5835,  ..., -0.2124,  0.1086, -0.4358],
        [-0.6992, -0.6841, -0.4309,  ...,  0.0903,  0.1219, -0.0770]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([29, 77, 72, 47, 93, 47, 89, 84,  5, 18, 29,  1,  2, 26, 89, 84, 77,  6,
        49, 44, 44, 32, 14, 77, 49, 88, 46, 90, 59, 22, 56, 49, 24, 41, 51, 94,
        32, 61, 34, 77, 28, 88, 74, 80, 19, 76, 50, 43, 49, 36],
       device='cuda:0')
tensor([[-0.7686, -0.3628, -0.4563,  ..., -0.1909,  0.4128, -0.2698],
        [-0.7461, -0.9634, -0.6055,  ..., -0.1503,  0.1522, -0.3623],
        [-0.7944, -0.5327, -0.5537,  ..., -0.1851, -0.0972, -0.1870],
        ...,
        [-0.8013, -0.637

tensor([[-0.7134, -0.6304, -0.5830,  ..., -0.2939,  0.0659, -0.2031],
        [-0.7930, -0.6978, -0.4956,  ..., -0.1885, -0.0836, -0.3684],
        [-0.5864, -0.5972, -0.6841,  ..., -0.4358,  0.6157, -0.2183],
        ...,
        [-0.7324, -0.4336, -0.4268,  ..., -0.1331,  0.0939, -0.1777],
        [-0.6094, -0.5210, -0.6572,  ...,  0.3284, -0.2238, -0.1332],
        [-0.9380, -0.5225, -0.3499,  ..., -0.3394, -0.0961, -0.1548]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,  50, 100,  49,  45, 100,   7,  40,  99,  18,  44,  97,  24,  43,
         26,  84,  89,  54,  78,  67,  26,  84,  64,  72,  34,  41,  43,  33,
         84,  56,  78,  94,  72,  18,  40,  73,  89,  71,   7,  43,  75,  88,
         55,  54,  40,  78,  73,  18,  99,  56], device='cuda:0')
tensor([[-6.4600e-01, -2.5122e-01, -3.6401e-01,  ..., -3.8062e-01,
         -5.2734e-01, -2.1045e-01],
        [-6.3818e-01, -7.0166e-01, -6.1865e-01,  ..., -2.9834e-01,
          3.1006e-01, -5.1

tensor([[-0.8984, -0.7339, -0.5479,  ...,  0.0845,  0.0494, -0.1454],
        [-0.5063, -0.7520, -0.4941,  ..., -0.1642, -0.0872, -0.1080],
        [-0.8843, -0.7432, -0.4363,  ..., -0.1129,  0.1356, -0.2615],
        ...,
        [-0.5586, -0.5635, -0.5928,  ..., -0.0594, -0.2244, -0.1570],
        [-0.6597, -0.5547, -0.4465,  ..., -0.0787,  0.0993, -0.1639],
        [-0.8164, -0.6846, -0.6797,  ..., -0.0014,  0.1345, -0.1036]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([  3,  85,  56,  54, 100,  77,  55,  70,  32,  50,  76,  55,  48,  96,
         63,   1,  87,  92,  18,  77,  38,  41,  78,  76,  37,  95,   5,  80,
         50,  58,  89,  18,  51,  83,  45,  73,  82,  11,  85,  43,  77,  24,
         59,  50,  98,   8,  76,  61,  30,  78], device='cuda:0')
tensor([[-6.8213e-01, -5.4004e-01, -4.8950e-01,  ..., -1.5442e-01,
          3.0273e-01, -1.4282e-01],
        [-8.5400e-01, -6.8408e-01, -7.2705e-01,  ...,  3.8757e-02,
          3.6694e-01,  2.2

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([48, 54, 33, 79, 13, 93, 17, 43,  8, 60, 51, 74, 98, 83, 86, 93, 18, 60,
        67, 96, 46, 34, 96, 77, 78, 46,  8, 75, 72, 16, 90, 24, 24, 87, 70, 59,
        56, 50, 73, 51, 83, 82, 39, 80,  2, 59, 74, 74, 74, 26],
       device='cuda:0')
tensor([[-0.7056, -0.5474, -0.3118,  ..., -0.1364, -0.2759, -0.0207],
        [-0.7642, -0.6743, -0.3752,  ..., -0.0616,  0.6606, -0.0525],
        [-0.9058, -0.2891, -0.5273,  ..., -0.0051,  0.1095, -0.3518],
        ...,
        [-0.9551, -0.6787, -0.2820,  ..., -0.2043,  0.0571, -0.1857],
        [-0.6240, -0.0876, -0.3655,  ..., -0.2113, -0.0049, -0.0685],
        [-0.5479, -0.8159, -0.5596,  ...,  0.2537,  0.0536, -0.0534]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 54,  76,  26,  56,  37,  39,  83,  79,   1,  53,  28,  81,  59,  37,
         99,  87,  88,  42,  68,  61,  53,  52,  97,  71,  29,   3,  72,  75,
          8,  27,  77

tensor([[-0.7310, -0.2081, -0.4631,  ..., -0.1172,  0.2340,  0.2097],
        [-0.5977, -0.6538, -0.3145,  ...,  0.1664, -0.0731, -0.2175],
        [-1.0703, -0.6631, -0.7476,  ..., -0.1709,  0.3130,  0.0753],
        ...,
        [-0.7134, -0.6030, -0.5479,  ..., -0.0078,  0.2800,  0.2136],
        [-0.8706, -0.6914, -0.3586,  ..., -0.0090, -0.0046, -0.2632],
        [-0.7686, -0.4221, -0.6035,  ..., -0.0370,  0.0693, -0.2488]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([77, 45, 97, 10, 78, 39, 81, 97, 68, 67, 96, 72, 89, 49, 53,  5, 77, 20,
        74, 55, 68, 78, 33, 70, 12, 37, 31, 58, 73,  0, 79, 97, 44, 53, 74, 31,
        73, 11, 50, 14, 67, 61, 23, 26, 84, 47, 88, 41, 56, 96],
       device='cuda:0')
tensor([[-0.7236, -0.6479, -0.5547,  ..., -0.1859,  0.1951, -0.3665],
        [-0.6382, -0.7500, -0.6602,  ..., -0.0663,  0.4275, -0.0210],
        [-0.6514, -0.5947, -0.5010,  ...,  0.0738,  0.4922, -0.2903],
        ...,
        [-0.8154, -0.866

tensor([[-0.8579, -0.8755, -0.4988,  ..., -0.0421,  0.1373, -0.2224],
        [-0.7051, -0.2627, -0.4023,  ..., -0.1398, -0.2991, -0.0299],
        [-0.7729, -0.4495, -0.2115,  ...,  0.0270,  0.3989, -0.4658],
        ...,
        [-0.6802,  0.1060, -0.6187,  ...,  0.0507,  0.1414, -0.1071],
        [-0.6138, -0.5308, -0.3765,  ..., -0.0891, -0.4319, -0.0832],
        [-0.4958, -0.7319, -0.3481,  ...,  0.0387,  0.1199, -0.2859]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([16, 57, 67, 74, 80, 31, 38, 37, 83, 82, 57, 73, 83, 75, 75, 62, 49, 41,
        81, 56, 33, 24, 91, 89, 77, 74, 93, 40, 45, 84, 54, 77, 18, 18, 53, 88,
        38, 55,  5, 80, 88, 71,  7, 97, 79, 96, 74, 77, 12, 49],
       device='cuda:0')
tensor([[-0.6255, -0.1716, -0.6343,  ..., -0.0822, -0.0451, -0.3494],
        [-0.7651, -0.5747, -0.5659,  ..., -0.3689,  0.1478, -0.2362],
        [-0.5903, -0.3584, -0.4556,  ..., -0.0709,  0.0656, -0.0554],
        ...,
        [-0.8291, -0.751

tensor([[-0.9116, -0.0232, -0.7500,  ...,  0.1897,  0.0861, -0.0826],
        [-0.6226, -0.4102, -0.5322,  ..., -0.0959,  0.4094, -0.2148],
        [-0.8145, -0.2725, -0.3545,  ...,  0.0565,  0.1561,  0.0673],
        ...,
        [-1.0449, -0.4114, -0.1715,  ..., -0.2715,  0.0112, -0.0807],
        [-0.9395, -0.4480, -0.0526,  ..., -0.3164,  0.0028,  0.0330],
        [-0.7876, -0.2715, -0.2852,  ...,  0.2448,  0.2291,  0.2786]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 58,  32,  23,  93,  72,  19, 100,  52,  75,  26,  73,  57,  56,  51,
         33,  34,  74,  97, 100,   5,  10,  43,  54,  97,  69,  30,  49,  43,
         82,  95,  97,  34,  81,  92,  76,  39,  78,  34,  11,  41,  16,  18,
         41,  43,  44,   5,  80,  47,  52,  77], device='cuda:0')
tensor([[-0.7905, -0.2778, -0.0015,  ..., -0.3235,  0.0961, -0.0355],
        [-0.7334, -0.7183, -0.4966,  ..., -0.1417,  0.2113,  0.1105],
        [-0.6606, -0.5444, -0.4424,  ...,  0.0251,  0.71

tensor([[-0.5928, -0.5278, -0.6040,  ..., -0.1619,  0.4014, -0.1769],
        [-0.9146, -0.8208, -0.5117,  ..., -0.2130,  0.4543,  0.3979],
        [-0.6191, -0.5752, -0.5020,  ..., -0.0166, -0.1976, -0.4370],
        ...,
        [-0.8350, -0.6323, -0.6172,  ..., -0.1209,  0.3792, -0.2225],
        [-0.6997, -0.5181, -0.6440,  ..., -0.2637,  0.3533, -0.4255],
        [-0.7793, -0.6099, -0.5088,  ..., -0.2554, -0.0790, -0.0306]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 13,  80,  61, 100,  40,  33,  85,  49,  93,  75,  93,  62,  17,  57,
         78,  78,  77,  74,  24,  80,  33,  25,  29,  12,  78,   1,  85,  55,
         59,  90,  57,  51,  70,   0,  96,  15,  25,  18,  51,  47,  89,  82,
         38,  28,  90,  47,  53,  43,  28,  73], device='cuda:0')
tensor([[-0.8799, -0.6587, -0.4133,  ..., -0.3892,  0.1005, -0.1842],
        [-0.6450, -0.7134, -0.5239,  ..., -0.1680, -0.0234, -0.0234],
        [-0.8340, -0.0274, -0.0023,  ..., -0.3271,  0.18

tensor([[-0.7979, -0.8374, -0.7188,  ..., -0.0775,  0.5933, -0.1710],
        [-0.7422, -0.5366, -0.5312,  ..., -0.3606,  0.2571,  0.0681],
        [-0.7905, -0.6362, -0.7612,  ..., -0.2114,  0.1407,  0.0133],
        ...,
        [-0.7798, -0.5508, -0.7964,  ...,  0.0817,  0.0455, -0.2686],
        [-0.7432, -0.4106, -0.2964,  ..., -0.4116,  0.1237, -0.0482],
        [-0.6284, -0.6504, -0.5522,  ...,  0.1483, -0.1533, -0.3940]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([97, 32, 78, 70, 90, 65, 35, 40,  6,  2, 56, 56, 74, 79, 73, 75, 89, 49,
        26, 48, 96, 82, 23, 33, 90, 55, 84, 97, 53,  0, 90, 95, 75, 57, 43, 20,
        71, 15, 67, 78, 88, 85, 28,  0, 82, 18, 64, 96, 74, 61],
       device='cuda:0')
tensor([[-1.0947, -0.3291, -0.2581,  ..., -0.4363,  0.4365, -0.1342],
        [-0.6787, -0.3647, -0.7139,  ...,  0.0868,  0.3582, -0.3015],
        [-0.7666, -0.5479, -0.8765,  ..., -0.1910,  0.2871, -0.2693],
        ...,
        [-0.8267, -0.799

tensor([[-0.5923, -0.7354, -0.6572,  ..., -0.2133,  0.2571, -0.1879],
        [-0.7642, -0.3806, -0.4641,  ...,  0.0366,  0.1720, -0.0704],
        [-0.6187, -0.0520, -0.3530,  ..., -0.1400,  0.7451, -0.2971],
        ...,
        [-0.6689, -0.0614, -0.6475,  ..., -0.0422,  0.2262, -0.2864],
        [-0.6870, -0.6318, -0.7578,  ..., -0.1205,  0.2566,  0.0999],
        [-0.5737, -0.4529, -0.6177,  ..., -0.0441,  0.2454, -0.0835]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 72,  23,  67,  83,  47,  91,  19,  55,  78,  77,   7,  78,  64,  76,
         73,  55,   8,  80,  17,  82,  30,  92,  43,  78,  24,  44,  84,  44,
        100,  22,  57,  87,  76,  70,  71,  49,  17,  93,   6,  90,  71,  93,
         15,  86,  18,  96,  56,  17,  35,  33], device='cuda:0')
tensor([[-0.3208, -0.4365, -0.5542,  ..., -0.1158, -0.1616, -0.1298],
        [-0.6436, -0.5947, -0.6577,  ..., -0.4534,  0.1974, -0.0940],
        [-0.7271, -0.6904, -0.7139,  ...,  0.1317,  0.53

tensor([[-0.8447, -0.6260, -0.3772,  ..., -0.2922,  0.0551, -0.2898],
        [-0.6914, -1.0908, -0.5239,  ...,  0.0113,  0.4358, -0.2661],
        [-0.8340, -0.3293, -0.1285,  ..., -0.2712,  0.0339,  0.0046],
        ...,
        [-1.1377, -0.5342, -0.4141,  ..., -0.1591,  0.3022,  0.0718],
        [-0.6055, -0.4421, -0.0496,  ..., -0.0974, -0.1511, -0.0699],
        [-0.8042, -0.4558, -0.6807,  ...,  0.0367,  0.1836, -0.0257]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([74, 84, 38,  3, 78, 98, 36, 43, 49, 63, 74, 32, 32, 40, 84, 52, 41, 47,
        64, 92, 74, 69, 63, 97, 45, 53, 11, 22, 66, 73,  4, 84, 49, 49, 40, 92,
        58, 34, 21, 23, 12, 52,  5, 72,  2, 32, 61, 60, 38, 96],
       device='cuda:0')
tensor([[-6.9238e-01, -5.6494e-01, -5.3662e-01,  ..., -2.0740e-01,
          4.1553e-01, -6.2549e-01],
        [-8.0420e-01, -7.5684e-01, -3.3472e-01,  ...,  6.9275e-02,
          2.8896e-04, -3.7476e-01],
        [-6.0107e-01, -7.0166e-01, -4.626

tensor([[-0.5693, -0.0601, -0.3762,  ...,  0.0168,  0.4175,  0.2600],
        [-0.5811, -0.6714, -0.6118,  ..., -0.1616,  0.3311, -0.3923],
        [-0.8857, -0.7075, -0.5552,  ..., -0.2585,  0.2000, -0.0024],
        ...,
        [-0.7236, -0.3137, -0.5498,  ..., -0.1859,  0.0912, -0.5151],
        [-1.0137, -0.7729, -0.6753,  ..., -0.0298,  0.7314, -0.3804],
        [-0.8091, -0.1855, -0.3115,  ..., -0.2969, -0.2510, -0.1190]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([77, 50, 69, 76, 52, 33, 75, 46, 43, 21,  1, 79, 54, 94, 49, 33, 93, 86,
        23, 60, 73, 85, 74, 33, 76,  6, 76, 78, 84, 65, 73, 96, 33, 21, 10, 39,
        77, 91, 48, 74, 54, 39, 84, 77, 96, 96, 43, 70, 97, 57],
       device='cuda:0')
tensor([[-0.5640, -0.3093, -0.6289,  ..., -0.0833, -0.2073, -0.3699],
        [-0.8145, -0.4614, -0.4731,  ..., -0.1417,  0.1313, -0.3354],
        [-0.6572, -0.7075, -0.6797,  ..., -0.1055, -0.0221, -0.1050],
        ...,
        [-0.9014, -0.479

tensor([[-0.8320, -0.6934, -0.5581,  ..., -0.3264,  0.0183, -0.5146],
        [-0.7100, -0.3804, -0.5640,  ..., -0.3030,  0.0514, -0.2595],
        [-0.8369, -0.6919, -0.4675,  ..., -0.3376,  0.3728,  0.8384],
        ...,
        [-0.7720, -0.6641, -0.3904,  ..., -0.2054, -0.1334, -0.2505],
        [-0.8447, -0.2228,  0.0312,  ..., -0.2998,  0.1029, -0.1135],
        [-0.6865, -0.5425, -0.3438,  ..., -0.0364,  0.1571, -0.4148]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 64,   9, 101,  49,  84,  51,   6,  90,   5,  74,  97,  47,  90,  49,
         76,  18,  81,  42,  45,  72,  18,  96,  73,  77,  49,  89,  56,  90,
         72,  26,  90,  34,  40,  81,  36,  31,  74,  91,  71,  43,  54,  83,
         37,  67,  12,  57,  34,  64,  52,   9], device='cuda:0')
tensor([[-0.6699, -0.4958, -0.6172,  ..., -0.1477,  0.3574, -0.1351],
        [-0.9385, -0.0911, -0.3843,  ..., -0.1708, -0.2126,  0.0107],
        [-0.5444, -0.5996, -0.4727,  ...,  0.1364,  0.14

tensor([[-0.7656, -0.0763, -0.0235,  ..., -0.0358, -0.1415, -0.2585],
        [-0.7324, -0.8726, -0.5859,  ...,  0.4470,  0.2090, -0.1458],
        [-0.6602, -0.6875, -0.5449,  ..., -0.0568,  0.1242, -0.4490],
        ...,
        [-1.0312, -1.0527, -0.6050,  ..., -0.4707,  0.1370,  0.3728],
        [-0.6426, -0.2661, -0.1309,  ..., -0.2542, -0.2600, -0.0578],
        [-0.4419, -0.8618, -0.5386,  ..., -0.0332,  0.0069,  0.0777]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 2, 99, 10,  9, 77, 53, 24, 89,  1, 29, 84, 75,  1, 45, 36, 92, 10, 43,
         6, 99, 30, 26, 90, 74, 62, 49, 49, 62, 74, 65, 29, 96, 13, 14, 41, 43,
        83, 49, 74, 17, 73, 44, 73, 60, 90, 93, 80, 82, 47, 72],
       device='cuda:0')
tensor([[-0.5024, -0.4875, -0.4016,  ...,  0.1334, -0.1298, -0.3257],
        [-0.6958, -0.2006,  0.0291,  ..., -0.3003, -0.0553, -0.1830],
        [-0.6904, -0.3882, -0.5552,  ..., -0.3589, -0.1676,  0.0584],
        ...,
        [-0.9341, -0.642

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 9, 84, 49, 56, 95, 21, 96, 26, 24, 41, 71,  5, 77, 68, 54, 84, 73, 96,
        37, 64, 98, 40, 84, 33, 26, 64, 92, 38, 70, 89, 33, 49, 49, 77, 26, 72,
        81, 92, 54, 68, 49, 10, 16, 97, 50, 70, 85, 89, 77, 54],
       device='cuda:0')
tensor([[-0.5942, -0.6626, -0.8281,  ..., -0.2334,  0.0119, -0.1554],
        [-0.7852, -0.7568, -0.6182,  ..., -0.2317,  0.0024,  0.0572],
        [-0.5112, -0.2512, -0.5537,  ..., -0.0184,  0.0224,  0.0128],
        ...,
        [-0.7549, -0.1490, -0.0687,  ..., -0.4751, -0.2671,  0.0720],
        [-0.6938, -0.8901, -0.4070,  ...,  0.1189,  0.0038,  0.1869],
        [-0.7627, -0.3013, -0.4797,  ..., -0.1090,  0.5728, -0.4580]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 89,  73,  18,  50,  82,  64,  93,  20,  62,  33,  46,  43,  37,  73,
         89,  29,  86,  38,  49,  69,  49,  13,  43,  90,   6,   5,  61,  75,
         13,   4,  85

tensor([[-0.5752, -0.2917, -0.6055,  ...,  0.1664, -0.0557, -0.6802],
        [-0.8086, -0.7407, -0.5220,  ..., -0.0158,  0.1780, -0.2505],
        [-0.5571, -0.2925, -0.7051,  ...,  0.1311, -0.0482, -0.2883],
        ...,
        [-0.8477, -0.5181, -0.5825,  ..., -0.2539,  0.2115,  0.4199],
        [-0.6646, -0.6279, -0.5649,  ...,  0.1675,  0.7070, -0.4854],
        [-0.7959, -0.5371, -0.5420,  ..., -0.2869,  0.0733, -0.1282]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([21, 56, 37, 38, 26, 28, 59, 89, 82, 77, 66, 62, 43, 61, 70, 78, 93, 85,
        96, 85, 38,  7, 23, 98, 55, 33, 87, 80, 43, 37, 49, 33, 32, 77, 81, 93,
        93, 90, 55, 93, 13, 54, 49, 65, 98, 89, 87, 80, 76, 74],
       device='cuda:0')
tensor([[-0.7695, -0.0339, -0.3943,  ..., -0.0712,  0.1774,  0.1606],
        [-0.7739, -0.5020, -0.4905,  ..., -0.4231,  0.0751, -0.0693],
        [-0.8467, -0.4602, -0.5063,  ..., -0.3584, -0.0840,  0.0231],
        ...,
        [-0.7246, -0.161

tensor([[-0.8779, -0.0292, -0.1294,  ..., -0.3926,  0.2407, -0.2300],
        [-0.8613, -0.8418, -0.3162,  ...,  0.0819,  0.0369, -0.0321],
        [-0.7393, -0.4639, -0.5879,  ..., -0.1945,  0.1924,  0.2023],
        ...,
        [-0.7329, -0.8218, -0.2437,  ...,  0.2257,  0.5459, -0.4055],
        [-0.4385, -0.4702, -0.6133,  ...,  0.2415,  0.3728, -0.4224],
        [-0.8320, -0.1287, -0.3479,  ..., -0.0614,  0.2578,  0.3640]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([62, 12, 73, 73, 90, 60, 15, 95, 10, 57, 43, 74, 55, 60,  3, 38, 56, 82,
         5, 59,  9, 95, 40, 40, 28, 76, 82, 76, 49, 84, 80, 90, 24, 32, 14, 51,
        33, 49, 45, 76, 83, 90, 92, 72, 78, 20, 15, 90, 36, 77],
       device='cuda:0')
tensor([[-0.6733, -0.7095, -0.7900,  ...,  0.0037,  0.2957, -0.1031],
        [-0.9888, -0.3169, -0.0213,  ..., -0.2515,  0.1068, -0.0099],
        [-0.8218, -0.2446, -0.3647,  ..., -0.1964, -0.1643, -0.0876],
        ...,
        [-1.1885, -0.045

tensor([[-0.4211, -0.7578, -0.7373,  ..., -0.1305,  0.4094, -0.4539],
        [-0.6616, -0.7437, -0.4631,  ...,  0.3372,  0.2637, -0.4180],
        [-0.7158, -0.7095, -0.4880,  ..., -0.1069, -0.1874, -0.1704],
        ...,
        [-0.6143, -0.2313, -0.4287,  ..., -0.1663,  0.0453, -0.1473],
        [-0.5283, -0.4224, -0.3350,  ...,  0.0101,  0.2805, -0.1847],
        [-0.7832, -0.3027,  0.0045,  ..., -0.2009, -0.1075, -0.0537]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 87,  90,  74, 100,  39,   0,  43,  86,  49,  64,  47,  56,  60,  72,
         97,  14,  82,  93,  54,   2,  24,   2,  63,  68,  41, 101,  74,  44,
         16,   0,  26,   1,  37,  52,  55,  45, 100, 101,  59,  54,  78,  30,
         49,  79,  69,  54,  16,  65,  79,  47], device='cuda:0')
tensor([[-0.8442, -0.6738, -0.7368,  ...,  0.0157,  0.2983,  0.0699],
        [-0.6030, -0.6660, -0.4844,  ..., -0.1974,  0.3274,  1.1279],
        [-0.7646, -0.0428, -0.1991,  ..., -0.5098, -0.29

tensor([[-0.8794, -0.5210, -0.2878,  ..., -0.1182,  0.1317, -0.0902],
        [-0.8535, -0.0406,  0.0979,  ..., -0.4216, -0.2177, -0.0745],
        [-0.4980, -0.2639, -0.4978,  ...,  0.0176,  0.5767, -0.5571],
        ...,
        [-0.4668, -0.6025, -0.4685,  ..., -0.0880,  0.2920, -0.4929],
        [-0.9360, -0.0579, -0.1813,  ...,  0.2542,  0.3423, -0.2208],
        [-0.6470, -0.9575, -0.6206,  ...,  0.0284, -0.2534,  0.0392]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([  4,  38,  25,  90,  37,  52,  68,  16,  41,  40, 100,  74,  41,   8,
         25,  45,  73,  80,  32,  55,  12,  82,  77,  78, 101,  49,  82,  32,
         89,  65,  18,  43,  57,  93,  14,  20,  72,  80,  48,  69,  62,  59,
         11,  76,  76,  39,  77,  49,  77,  86], device='cuda:0')
tensor([[-0.9253, -0.7202, -0.9248,  ..., -0.3672, -0.1768,  0.1504],
        [-0.7461, -0.3774,  0.1686,  ..., -0.3938, -0.1986,  0.0731],
        [-0.5981, -0.8066, -0.5205,  ...,  0.1085,  0.61

tensor([[-0.7739, -0.7852, -0.5107,  ..., -0.1146,  0.0045,  0.0671],
        [-0.9189, -0.8950, -0.6230,  ..., -0.3105,  0.3223,  0.0748],
        [-0.7192, -0.4746, -0.0169,  ..., -0.2275,  0.1406,  0.2334],
        ...,
        [-0.6094, -0.8345, -0.5049,  ..., -0.1895, -0.0736, -0.3101],
        [-0.7227, -0.1570, -0.1133,  ..., -0.2169,  0.1656, -0.1879],
        [-0.9985, -0.2566, -0.8594,  ...,  0.3513,  0.0072,  0.1824]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 73,  69,  38,  43,  38,  76,  95,  14,  81,  71,  63,  11,  98,  61,
         43,  70,  73,  78,  43,  50,  82,  12,   5,  94,  56,  62,  73,   1,
         77,  73,  97,  85,  24,  89,  33,  15,   0,  33,  43, 101,  44,  87,
         54,  75,  75,  78,  76,  12,  44,  58], device='cuda:0')
tensor([[-0.7236, -0.4407, -0.4185,  ...,  0.4700,  0.5039, -0.3137],
        [-0.6484, -0.5005, -0.3743,  ..., -0.2059, -0.3159,  0.0880],
        [-0.5859, -0.4519, -0.3677,  ...,  0.0618, -0.13

tensor([[-0.6909, -0.4995, -0.4993,  ..., -0.1394,  0.3140, -0.2471],
        [-0.8560, -0.5581, -0.6245,  ..., -0.1965,  0.5122,  0.0140],
        [-0.9634, -0.2905, -0.4207,  ...,  0.4707,  0.0589, -0.4177],
        ...,
        [-0.5234, -0.8599, -0.5210,  ...,  0.2311, -0.3782, -0.1218],
        [-0.7612,  0.0401, -0.0638,  ..., -0.0230,  0.0162,  0.5947],
        [-0.7944,  0.1940,  0.1465,  ..., -0.0489, -0.0288, -0.4043]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 55,  36,  10,  26,  26,  91,   3,  72,  96,   5, 100,  45,  39,  59,
         26,  96,  60,  64,  43,  56,  33,   3,  18,  92,  59,  83,  18,  42,
         57,  78,  77,  33,  80,  70,  64,  39,  92,  53,  85,  70,  93,  78,
         96,  13,  25,  93,  23,  85,   7,   6], device='cuda:0')
tensor([[-0.5596, -0.6919, -0.5269,  ..., -0.1479,  0.8140,  0.0221],
        [-1.0010, -0.5376, -0.1530,  ...,  0.0705, -0.3262, -0.1021],
        [-0.7334, -0.6582, -0.3813,  ..., -0.1445,  0.01

tensor([[-6.8164e-01, -6.3330e-01, -6.6211e-01,  ..., -2.7344e-01,
         -5.7648e-02, -4.1270e-04],
        [-4.7534e-01, -2.7246e-01, -4.4385e-01,  ...,  9.7046e-02,
         -4.4141e-01, -8.6670e-02],
        [-4.8706e-01, -4.0552e-01, -7.4268e-01,  ...,  2.3773e-02,
          6.5079e-03,  2.6245e-03],
        ...,
        [-4.8755e-01, -3.8013e-01, -5.8301e-01,  ...,  4.0833e-02,
          3.2031e-01, -6.0938e-01],
        [-7.7490e-01, -2.8784e-01, -1.4746e-01,  ..., -1.2903e-01,
         -4.8120e-01, -5.4901e-02],
        [-6.4062e-01, -6.6113e-01, -8.3679e-02,  ..., -4.5197e-02,
          2.8735e-01, -8.3130e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([78, 18, 35, 99, 96, 28, 82, 53, 40, 77, 49, 43, 64,  7, 23, 64, 86, 84,
        59, 30, 22, 53, 23, 83, 83, 74,  6, 57, 59, 50,  8, 96, 75, 43, 89, 84,
        29, 43, 97,  7, 80, 77, 59, 96, 65, 15, 26, 36, 57, 40],
       device='cuda:0')
tensor([[-0.6821, -0.5806, -0.4312,  ...,  0.0927

tensor([[-0.2363, -0.4900, -0.5166,  ..., -0.1026, -0.1239, -0.0627],
        [-0.5923, -0.5322, -0.6167,  ..., -0.4988,  0.2482, -0.0605],
        [-0.6304, -0.6826, -0.7930,  ...,  0.2705,  0.7871, -0.3254],
        ...,
        [-0.5273,  0.0478, -0.3794,  ...,  0.1696,  0.2300, -0.2627],
        [-0.5728, -0.7310, -0.4504,  ..., -0.1417,  0.0425, -0.3245],
        [-0.7339, -0.8008, -0.6230,  ..., -0.1193, -0.4783,  0.4138]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 83,  66,  76,  78,  15,  90,  18,  70, 100,  37,  68,  24,   7,  28,
         73,  98,  32,  43,  96,  78,  64,  90,  75,  10,  64,  77,  93,  44,
         17,  82,  14,  26,   9,  96,  54,  52,  56,  30,  15,  74,  90,  40,
         75,  43,  12,  17,  90,  94,  50,  78], device='cuda:0')
tensor([[-0.7764, -0.8364, -0.5630,  ..., -0.0454,  0.9072,  0.0849],
        [-0.7129, -0.0917, -0.4229,  ...,  0.1344,  0.2964,  0.3740],
        [-0.5547, -0.3599, -0.3967,  ...,  0.1417, -0.50

tensor([[-0.8247, -0.5459, -0.2783,  ..., -0.2317,  0.2026,  0.0562],
        [-0.8164,  0.1776, -0.0356,  ...,  0.1026, -0.2094, -0.0515],
        [-0.9336, -1.0273, -0.7158,  ...,  0.3022,  0.5244,  0.4675],
        ...,
        [-0.8687, -0.8921, -0.5679,  ..., -0.1522,  0.5454, -0.2644],
        [-0.5024, -0.5098, -0.7061,  ...,  0.0290,  0.8335, -0.8413],
        [-0.5645, -0.8110, -0.2585,  ...,  0.1043, -0.1953, -0.7041]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,   6,  81, 101,  18,  11,  49,  75,  84,  24,  40,  27,  23,  87,
         87,  27,  74,  74,  34,  85,  97,  43,  98,  38,  54,  90,  79,  33,
         40,  62,  33,  96,  41,  13,  73,  72,  99,  97,  77,  74,   3,  54,
         87,  45,  80,  73,  77,  43,  36,  45], device='cuda:0')
tensor([[-0.9541, -0.8247, -0.3472,  ..., -0.1305,  0.5122,  1.0000],
        [-0.8325, -1.0537, -0.4521,  ...,  0.1947,  0.4656, -0.1445],
        [-0.4438, -0.9551, -0.6484,  ...,  0.0395,  0.54

tensor([[-0.5015, -0.2372, -0.6855,  ...,  0.0075, -0.2664, -0.4810],
        [-0.7124, -0.4744, -0.4949,  ..., -0.1174,  0.1403, -0.3821],
        [-0.6646, -0.7227, -0.6172,  ..., -0.0609, -0.1385, -0.0209],
        ...,
        [-0.8423, -0.4436, -0.7012,  ...,  0.4165,  0.1682, -0.5459],
        [-0.6738, -0.7812, -0.7275,  ..., -0.0310,  0.6465,  0.1050],
        [-0.9604, -1.0381, -0.4819,  ..., -0.7222,  0.1920,  0.6201]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 85,  55,  78,  98,  84,  49,  69,  77,  39,  60,  57,  88,  89,  48,
          8,  80,  87,  75,  92,  57,   1,  55,  78,  20,  84,  72,  20,  10,
          2,  67,  10,  64,  45,  68,  49,  50,  56,  66,  77,  33,  83,  79,
         78,  91, 100,  21,  94,  96,  95,  82], device='cuda:0')
tensor([[-0.7891, -0.6396, -0.6279,  ...,  0.0257,  0.2393, -0.2803],
        [-0.7251, -0.2952, -0.4697,  ..., -0.3274,  0.1142, -0.1120],
        [-0.9331, -0.8691, -0.5894,  ..., -0.2487,  0.34

tensor([[-0.7812,  0.6245, -0.1744,  ..., -0.2163,  0.4104, -0.1375],
        [-0.8335,  0.0017,  0.2133,  ..., -0.4631, -0.2615,  0.0944],
        [-1.1025, -0.1985, -0.0539,  ..., -0.2964,  0.3176,  0.0074],
        ...,
        [-0.6890, -0.6523, -0.2148,  ..., -0.2477,  0.3801, -0.4199],
        [-0.9209, -0.6924, -0.5586,  ..., -0.0765,  0.5752,  0.9399],
        [-0.4587, -0.6221, -0.8496,  ..., -0.0326,  0.6763, -0.7480]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 31,  52,  22,  24,  44,  71,  77,  75,   5,  78,  64,  94,  49, 101,
        101,  43,  98,  62,  87,  32,  37,  72,  84,  90,  42,  28,  73,  92,
         98,  96,  73,  20,  68,  12,   7,  38,  90,  89,  80,  93,   6,  85,
         85,  83,  31,  77,  19,  42,  19,  36], device='cuda:0')
tensor([[-0.6812, -0.8071, -0.5796,  ..., -0.0517,  0.7144, -0.0945],
        [-0.7090, -0.6709, -0.4844,  ..., -0.2500, -0.1519,  0.2018],
        [-0.5088, -0.4285, -0.6997,  ...,  0.1394,  0.37

tensor([[-0.3994, -0.4729, -0.4167,  ...,  0.3000, -0.0989, -0.4561],
        [-0.7109, -0.0987,  0.2466,  ..., -0.3123, -0.1030, -0.1525],
        [-0.7134, -0.3953, -0.5728,  ..., -0.3752, -0.2537,  0.1260],
        ...,
        [-0.9634, -0.6108, -0.4961,  ...,  0.2915,  0.3164,  0.5312],
        [-0.6196,  0.2289, -0.2502,  ...,  0.1006,  0.3743,  0.1694],
        [-0.5405, -0.6011, -0.5723,  ...,  0.4160,  0.1646, -0.4177]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 70,  52,  74,  77,  74,  90,  38,  74,  48, 101,  41,  73,  33,  31,
         46,  82,  18,   7,  39,  71,  73,  96,  90,  92,  48,  89,  26,  28,
         54,  38,  55,  33,  73,  70,  16,  13,  83,  85,  77,  49,  19,  65,
         94,  78,  86,  27,   7,  60,  94,  90], device='cuda:0')
tensor([[-0.7798, -0.3853,  0.2213,  ..., -0.0730, -0.4333,  0.1820],
        [-0.5215, -0.1913, -0.5273,  ...,  0.3467, -0.0781, -0.3669],
        [-0.7695,  0.0589, -0.2717,  ...,  0.2695,  0.33

tensor([[-0.8208, -0.8960, -0.5566,  ..., -0.6001,  0.4033,  0.1361],
        [-0.6763, -0.5444, -0.5264,  ..., -0.1486, -0.0180, -0.3372],
        [-0.6133, -0.6655, -0.7192,  ...,  0.2371, -0.1104,  0.3665],
        ...,
        [-0.9380,  0.5083, -0.2476,  ..., -0.2365,  0.1970, -0.2363],
        [-1.0352, -0.6235, -0.6816,  ...,  0.3689,  0.4602,  0.0743],
        [-0.9453, -0.2771, -0.1171,  ...,  0.0013,  0.5527,  0.1199]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 82,  66,  37,  77,  78,  54,  20,  73,  90,  49,  91,  18,  33,  54,
         43,  39,  61,  32,  38,  80,  60,  96,  72,  24,  19,  82, 101,  97,
          3,  16,  77,  14,  74,  71,  54,  83,   6,  60,  89,  83,  23,  83,
         57,  78,  81,   6,  44,  31,  40,  22], device='cuda:0')
tensor([[-0.8853, -0.2563, -0.6484,  ..., -0.1473,  0.8716, -0.5234],
        [-0.8638, -0.6992,  0.0139,  ..., -0.1036, -0.4370, -0.0205],
        [-0.5215, -0.2615, -0.0938,  ..., -0.2399,  0.08

tensor([[-0.7295,  0.1809, -0.2693,  ..., -0.0251,  0.1654,  0.2435],
        [-0.8027, -0.5132, -0.4722,  ..., -0.4304,  0.0228, -0.0018],
        [-0.8569, -0.4092, -0.4561,  ..., -0.4097, -0.2141,  0.1901],
        ...,
        [-0.6646, -0.0260,  0.0057,  ..., -0.4746, -0.1851, -0.2345],
        [-0.6426,  0.0466, -0.4670,  ...,  0.1003,  0.2891,  0.2617],
        [-0.7866, -0.0477, -0.4568,  ...,  0.0984,  0.5273, -0.0665]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([77, 74, 73, 54, 80, 54, 68, 50, 98, 89, 12, 45, 47, 16, 14, 82, 61, 93,
        48, 86, 70, 36, 58, 10, 43, 98, 89, 92, 83, 89, 97, 75, 28, 49, 76, 43,
        87, 18, 75, 73, 42, 89,  2, 85, 38, 80, 77, 46, 77, 77],
       device='cuda:0')
tensor([[-0.4836, -0.1973, -0.2520,  ..., -0.3936,  0.5156,  0.0061],
        [-0.6543, -0.5488, -0.6865,  ..., -0.1289, -0.6162, -0.3887],
        [-0.6831, -0.8066, -0.6138,  ..., -0.2825,  0.9551, -0.3533],
        ...,
        [-0.5610, -0.775

tensor([[-0.7520,  0.4434, -0.0206,  ..., -0.1218,  0.1449,  0.2460],
        [-0.7100, -0.8237, -0.5542,  ...,  0.1456,  0.4851,  0.0232],
        [-0.3442, -0.5752, -0.3245,  ...,  0.1676, -0.0345, -0.6060],
        ...,
        [-0.3875,  0.1953, -0.5498,  ..., -0.3264,  0.3257, -0.4590],
        [-0.2800, -0.8247, -0.4167,  ...,  0.1024,  0.4817, -0.3650],
        [-0.6973, -0.6440, -0.3181,  ...,  0.1384,  0.4062, -0.1011]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.6743,  0.0143, -0.2598,  ..., -0.1063,  0.4128,  0.4236],
        [-0.3647, -0.6079, -0.1805,  ...,  0.3489,  0.0363, -0.4868],
        [-1.0723, -0.7656, -0.8867,  ...,  0.1245,  0.58

tensor([[-0.8569, -0.7222, -0.7559,  ...,  0.1213,  0.3635,  0.1902],
        [-0.6250, -0.6377, -0.3997,  ..., -0.2610,  0.3330,  1.5869],
        [-0.7334,  0.0844, -0.0111,  ..., -0.5698, -0.3921, -0.3416],
        ...,
        [-0.6519,  0.3142,  0.4941,  ..., -0.0300, -0.6094, -0.2357],
        [-0.6597, -0.9404, -0.6836,  ...,  0.0717,  0.5571, -0.1940],
        [-0.3137, -0.3896, -0.5596,  ...,  0.0519, -0.0181, -0.5361]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([  3, 101,  46,  68,  64,   5,   7,  14,  77,  43,  73,  14,  70,  99,
         98,  56,  16,  67,   5,  89, 101,  89,  38,  38,  59,  73,   4,  79,
         38,  74,   2,  36,  24,  11,  45,  74,  73,  49,  77,  15,  59,  11,
         80,  37,  77,  73,  97,   6,  72,  45], device='cuda:0')
tensor([[-0.6641, -0.0070, -0.3591,  ..., -0.0922,  0.3240,  0.0342],
        [-0.8735, -0.4326, -0.4700,  ...,  0.1177,  0.0629,  0.4214],
        [-0.5347, -1.0771, -0.5327,  ...,  0.4114,  0.92

tensor([[-0.9976,  0.2469, -0.7344,  ...,  0.3755, -0.0767,  0.1910],
        [-0.6021, -0.4343, -0.6558,  ..., -0.1614,  0.6841,  0.0164],
        [-0.7383, -0.1425, -0.3000,  ...,  0.0360,  0.3242,  0.2147],
        ...,
        [-0.9736, -0.2581,  0.2162,  ..., -0.2069, -0.0526, -0.0902],
        [-0.9561, -0.2593,  0.3657,  ..., -0.3152, -0.1194,  0.1936],
        [-0.6958, -0.1747, -0.1282,  ...,  0.4370,  0.3235,  0.4668]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 58,  32,  23,  93,  72,  19, 100,  52,  75,  26,  73,  57,  56,  51,
         33,  34,  74,  97, 100,   5,  10,  43,  54,  97,  69,  30,  49,  43,
         82,  95,  97,  34,  81,  92,  76,  39,  78,  34,  11,  41,  16,  18,
         41,  43,  44,   5,  80,  47,  52,  77], device='cuda:0')
tensor([[-0.8291, -0.1088,  0.4849,  ..., -0.4319, -0.0902,  0.1792],
        [-0.7769, -0.6973, -0.3821,  ..., -0.1731,  0.0948,  0.5767],
        [-0.4558, -0.5776, -0.5024,  ...,  0.1826,  1.47

tensor([[-0.6299, -0.3777, -0.3621,  ...,  0.5410,  0.6094, -0.3279],
        [-0.6304, -0.4468, -0.2686,  ..., -0.2477, -0.4417,  0.2473],
        [-0.5303, -0.3567, -0.2500,  ...,  0.1503, -0.1849,  0.3638],
        ...,
        [-0.7310,  0.0306, -0.3259,  ...,  0.2156,  0.2097,  0.3037],
        [-0.8784, -0.7041, -0.8955,  ..., -0.3640,  0.5229,  0.1036],
        [-0.5205, -0.8633, -0.4756,  ...,  0.1389,  1.0518, -0.1410]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 23,  73,  77,  52,  90,  89,  73,  26,  43,  59,  95,  91,  43,   7,
         11,  89,  58,  63,  75, 100,  76,   9,  73,  81,  68,   4,  44,  77,
         34,  92, 101,  71,  89,  48,  11,  90,  54,  73,  81,  79,  74,  40,
         46,  35,  48,  43,  45,  77,  55,  84], device='cuda:0')
tensor([[-0.6519,  0.1043,  0.2815,  ...,  0.0285, -0.2224, -0.5449],
        [-0.6597, -0.5815, -0.5547,  ...,  0.0273,  0.6606, -0.2632],
        [-0.5317, -1.0625, -0.4119,  ...,  0.2556,  0.94

tensor([[-0.7095, -0.9712, -0.7476,  ...,  0.1304,  0.9077, -0.1970],
        [-0.7544, -0.6509, -0.5366,  ..., -0.5161,  0.3235,  0.4624],
        [-0.8267, -0.5664, -0.6597,  ..., -0.1642, -0.0663,  0.2017],
        ...,
        [-0.6470, -0.3696, -0.7197,  ...,  0.2383,  0.0115, -0.3196],
        [-0.7876, -0.4294, -0.2174,  ..., -0.4373,  0.0737,  0.1337],
        [-0.4163, -0.6123, -0.4707,  ...,  0.3787, -0.2433, -0.5649]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([97, 32, 78, 70, 90, 65, 35, 40,  6,  2, 56, 56, 74, 79, 73, 75, 89, 49,
        26, 48, 96, 82, 23, 33, 90, 55, 84, 97, 53,  0, 90, 95, 75, 57, 43, 20,
        71, 15, 67, 78, 88, 85, 28,  0, 82, 18, 64, 96, 74, 61],
       device='cuda:0')
tensor([[-1.1592, -0.0428,  0.0043,  ..., -0.4775,  0.4126,  0.0894],
        [-0.5132, -0.1092, -0.6436,  ...,  0.3696,  0.5571, -0.4639],
        [-0.6748, -0.4705, -0.9692,  ...,  0.0235,  0.2888, -0.3191],
        ...,
        [-0.8950, -0.906

0it [00:00, ?it/s]

Found best model
Training: Epoch 10 || Loss:   2.775 || Accuracy:  51.18%
tensor([[-0.7490,  0.5127,  0.0430,  ..., -0.1069,  0.1246,  0.2788],
        [-0.6953, -0.8486, -0.5537,  ...,  0.1906,  0.5239,  0.0501],
        [-0.2986, -0.5605, -0.3071,  ...,  0.2090, -0.0266, -0.6548],
        ...,
        [-0.3438,  0.2537, -0.5229,  ..., -0.3411,  0.3608, -0.4839],
        [-0.2325, -0.8379, -0.4009,  ...,  0.1244,  0.5356, -0.3713],
        [-0.6812, -0.6572, -0.3115,  ...,  0.1686,  0.4285, -0.0821]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.6572,  0.0522, -0.2241,  ..., -0.1024,  0.4412,  0.4590],
        [-0.3198, -0.5962, -0.1549,  ...,  0.3809,  

tensor([[-0.7324, -0.6865, -0.6226,  ..., -0.0601,  0.0894, -0.2297],
        [-0.5264, -0.8135, -0.5088,  ..., -0.1493, -0.0722, -0.6675],
        [-0.4573, -0.5522, -0.8467,  ..., -0.3359,  1.4844, -0.2115],
        ...,
        [-0.5859, -0.0754, -0.3420,  ...,  0.3965,  0.2817, -0.3411],
        [-0.5039, -0.4397, -0.6035,  ...,  1.3027, -0.3813, -0.2252],
        [-0.9653, -0.6016, -0.3320,  ..., -0.1705, -0.1503, -0.0555]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,  50, 100,  49,  45, 100,   7,  40,  99,  18,  44,  97,  24,  43,
         26,  84,  89,  54,  78,  67,  26,  84,  64,  72,  34,  41,  43,  33,
         84,  56,  78,  94,  72,  18,  40,  73,  89,  71,   7,  43,  75,  88,
         55,  54,  40,  78,  73,  18,  99,  56], device='cuda:0')
tensor([[-0.7148, -0.1927, -0.3137,  ..., -0.4348, -0.8564, -0.2424],
        [-0.6914, -0.6016, -0.4805,  ..., -0.4617,  0.1569,  0.2013],
        [-0.5615, -0.0318,  0.1729,  ..., -0.0742, -0.17

tensor([[-0.7285,  0.6646, -0.0317,  ...,  0.0134,  0.2573, -0.0324],
        [-0.5601, -0.9277, -0.5415,  ...,  0.3752,  0.1472, -0.0899],
        [-0.6494,  0.1265, -0.1744,  ...,  0.1738,  0.1526,  0.0119],
        ...,
        [-0.5781, -0.6245, -0.7432,  ...,  0.3018,  0.7397, -0.4573],
        [-0.7124, -0.4094, -0.7441,  ..., -0.1119,  0.0784, -0.1862],
        [-0.7925, -1.0225, -0.5317,  ...,  0.2961,  0.6118, -0.4060]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 24,  98,  77,  72,  82,  92,  64,  67,  24,  40,  83,  90,  98,   1,
          3,  85,  68,  53,  15,  93,  29,  26,   2,  55, 100, 100,  91,  95,
         52,  34,  23,  64,  41,  42,  98,  68,  40,  86,  33,  16,  75,  83,
         99,  54,  74,   6,  77,  97,  93,  90], device='cuda:0')
tensor([[-1.0059,  0.2922, -0.7290,  ...,  0.4075, -0.1004,  0.2330],
        [-0.5928, -0.4343, -0.6729,  ..., -0.1700,  0.7275,  0.0537],
        [-0.7183, -0.1191, -0.2878,  ...,  0.0348,  0.35

tensor([[-0.7358,  0.3887,  0.3020,  ..., -0.1571,  0.3496, -0.2308],
        [-0.4836, -0.6279, -0.3643,  ...,  0.0073, -0.6499, -0.0439],
        [-0.7856, -0.0012, -0.5195,  ..., -0.5381,  0.3716,  0.1691],
        ...,
        [-0.6113, -0.6445, -0.4768,  ..., -0.1213, -0.0618, -0.2236],
        [-0.8101, -0.6387, -0.4668,  ..., -0.1281,  0.0511, -0.0655],
        [-0.6421, -0.7388, -0.2206,  ..., -0.2781,  0.4141,  0.2883]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([48, 54, 33, 79, 13, 93, 17, 43,  8, 60, 51, 74, 98, 83, 86, 93, 18, 60,
        67, 96, 46, 34, 96, 77, 78, 46,  8, 75, 72, 16, 90, 24, 24, 87, 70, 59,
        56, 50, 73, 51, 83, 82, 39, 80,  2, 59, 74, 74, 74, 26],
       device='cuda:0')
tensor([[-0.7402, -0.5220, -0.0536,  ..., -0.0562, -0.6562,  0.1959],
        [-0.6484, -0.7959, -0.4653,  ..., -0.1156,  1.4814,  0.2375],
        [-0.6548, -0.2673, -0.6050,  ...,  0.2260,  0.4675, -0.6401],
        ...,
        [-0.8887, -0.833

tensor([[-0.9966, -0.4275, -0.4795,  ...,  0.1149,  0.3955, -0.7310],
        [-0.4912, -0.7095, -0.0431,  ..., -0.1198,  0.3169, -0.5107],
        [-0.6191, -1.0225, -0.2688,  ...,  0.2527,  0.8979, -0.1102],
        ...,
        [-0.6719, -0.8892, -0.4348,  ..., -0.0262,  0.2283, -0.2883],
        [-0.7212, -0.3342, -0.5107,  ..., -0.0887, -0.2717,  0.2698],
        [-0.4873, -0.0847, -0.5469,  ...,  0.4094, -0.1055, -0.1214]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 49,  64,  84,  54,  59,  80,  88,  37,  84,  96,   5,  98,  73,   5,
         61,  15,  67,   9,  57,  49,  44,  75,  96,  76,   7,  49, 100,  88,
         50,  71,  43,  31, 101,   9,  97,   6,  21,  67,  22,   9,  48,  97,
         84,  49,   6,  84,  81,  90,  39,  18], device='cuda:0')
tensor([[-0.5283, -1.2051, -0.4294,  ...,  0.4636,  0.6382,  0.2437],
        [-0.9858, -0.2117,  0.4138,  ..., -0.5425,  0.0543,  0.1626],
        [-0.6294,  0.0312, -0.4600,  ..., -0.1320,  0.72

tensor([[-0.3350, -0.6855, -0.6782,  ..., -0.3008,  1.0449, -0.3252],
        [-0.2930, -0.9019, -0.1284,  ..., -0.1549,  0.8784,  0.1026],
        [-0.8726,  0.0641,  0.3887,  ...,  0.1913, -0.1646, -0.3367],
        ...,
        [-0.6855, -1.0967, -0.3359,  ..., -0.0376,  0.0138,  0.3201],
        [-0.5068, -0.7368, -0.1731,  ..., -0.1791, -0.2413, -0.7710],
        [-0.4656, -0.1436, -0.4426,  ..., -0.0407,  0.1093, -0.5332]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 28,  72,  52,  40,  47,  23,  47,  33,  38,  45,  12,  86,  80,  97,
         96,  48,  70,  78,  38,   6,  96,  30,  90,  74,  96,   5,  58,  75,
        101,  40,  38,  77,  44,  45,  64,  30,  46,  33,  49,  36,  54,  90,
         64,  51,  43,  31, 101,   4,  64,  66], device='cuda:0')
tensor([[-0.4824, -0.6597, -0.7485,  ..., -0.1412, -0.1464,  0.2524],
        [-0.4131, -0.3494, -0.2087,  ..., -0.1436,  0.1372, -0.4036],
        [-0.5967, -0.8208, -0.5024,  ...,  0.1349,  0.41

tensor([[-0.5117, -0.7275, -0.3301,  ...,  0.4424,  0.1141, -0.4946],
        [-0.4495, -0.5850, -0.6201,  ...,  0.5234,  0.1060, -0.5332],
        [-0.6426,  0.1007, -0.6587,  ..., -0.1373,  0.5049, -0.0281],
        ...,
        [-0.7544, -0.5518, -0.1167,  ..., -0.0690,  0.2007,  0.4807],
        [-0.5190, -0.4800, -0.6328,  ...,  1.3008,  0.0154, -0.3101],
        [-0.1460, -0.8267, -0.6978,  ...,  0.2065,  1.2363, -0.8403]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 51,  96,  94,  49,  20,  89,  55,  98,  78,  49,  32,  98,  64,  74,
         87,   8,  17,  77,  49,   5,  75,  24,  38,  89,  89,   2,  11,  76,
         42,  88,  64,  28,  59,  98,  89,   3,  84,  98,  87,  43,  33,  48,
         70,  40,  33, 100,  78,   4,  99,  53], device='cuda:0')
tensor([[-0.6934, -0.9429, -0.5181,  ...,  0.0197,  0.4548,  0.2610],
        [-0.6406, -0.0751,  0.7090,  ..., -0.1271, -0.1993,  0.0770],
        [-0.5742, -1.2129, -0.6025,  ...,  0.0831,  1.01

tensor([[-0.9121, -0.7573, -0.6650,  ..., -0.4150,  0.0473,  0.6187],
        [-0.3640,  0.0394, -0.4763,  ...,  0.2812,  0.2983, -0.7441],
        [-0.6646, -0.9653, -0.4236,  ...,  0.3469,  1.1260, -0.1720],
        ...,
        [-0.5239, -0.9297, -0.2172,  ...,  0.3118,  0.6211, -0.2256],
        [-0.4819, -0.7480, -0.2666,  ...,  0.8442,  0.4622, -0.2681],
        [-1.0908, -0.1444, -0.0494,  ..., -0.0075,  0.3167,  1.3877]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 73,  96,  84,  73,  49,  59,  59,  40,  71,  81,  91,  80,  40,   6,
         74,  25,  46,  91,  56,  98,  21,  45,  69,  39,  44,   6,  32,  81,
         38,  96, 100,  89,  18,  78,  99,  77,   8,   3,  74,  87,  59,  23,
         78,  74,  87,   9,   3,  49,  99,  95], device='cuda:0')
tensor([[-0.5854,  0.3752, -0.0263,  ...,  0.0661, -0.0454,  0.4219],
        [-0.6743,  0.5850,  0.5122,  ..., -0.1823,  0.0908, -0.3406],
        [-0.7393, -0.7559, -0.1941,  ..., -0.1172,  1.02

tensor([[-0.6899, -0.6904, -0.6616,  ...,  0.0466,  0.9771,  0.3425],
        [-0.3406, -0.5127, -0.1879,  ..., -0.1562,  1.2266, -0.3105],
        [-0.3027, -1.0547, -0.4634,  ..., -0.0175,  1.1211, -0.0681],
        ...,
        [-0.3906, -0.3682, -0.4114,  ...,  0.6221, -0.2651, -0.2437],
        [-0.8047, -0.8228, -0.4307,  ..., -0.0581,  0.2622, -0.3760],
        [-0.4302, -0.4309, -0.6729,  ...,  0.1051, -0.0757, -0.3716]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 32,  49,  72,  92,  37,  92,  20,  63,  74,  86,  71,   5,  41,  62,
          5,  82,   7,  29,  53, 100,  90,  49,  20,  17,  87,  78,  88,  24,
         44,  84,  56,  83,  42,  26,  47,  57,  43,  62,  18,  56,  19,   6,
         51,  63,  98,  85,   6,  37,  43,  89], device='cuda:0')
tensor([[-0.6309, -0.2720, -0.2339,  ..., -0.1890, -0.0797, -0.3618],
        [-0.6694, -0.9585, -0.8384,  ...,  0.5552,  0.4358,  0.2107],
        [-0.5327, -0.7715, -0.9307,  ...,  0.0093, -0.09

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 22,  77,  49,  53,  12,  22,  14,  12,  49,  53,  52,  89,  83,  86,
         62,  40,   8,  41,  73,  38,  68,  70,   4,  64,  32,  68,  86,  82,
         11,  66,  13,  59,  43,  15,  61,  90,  61,  55,  58,  27, 100,  90,
         72,  96,  89, 100,  14,  71,  26,  64], device='cuda:0')
tensor([[-0.7671, -0.9385, -0.6201,  ...,  0.6221,  0.1154,  1.0322],
        [-0.4661, -0.6899, -0.3235,  ..., -0.0719, -0.2161, -0.7705],
        [-0.4346, -0.2563, -0.2905,  ...,  0.5142, -0.0983, -0.5181],
        ...,
        [-0.6108, -0.0129, -0.4517,  ...,  0.0925, -0.6255, -0.6802],
        [-0.8604, -0.2169, -0.3416,  ...,  0.1241, -0.1619,  0.8662],
        [-0.6509,  0.1577, -0.1102,  ...,  0.1109,  0.1895,  0.4033]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 81,  64,  10,  58,  11,  14,  26,  19,  37,  47,  49,  61,  24,  36,
         40,  52,  40,  71,  38,  77,  77,  35, 

tensor([[-0.4236, -0.7183, -0.6675,  ...,  0.0945, -0.2957,  0.7817],
        [-0.9878, -1.0068, -0.4836,  ..., -0.2346,  0.0894,  0.7412],
        [-0.6748, -1.0176, -0.8252,  ...,  0.2983,  1.2988,  0.4294],
        ...,
        [-0.3042, -0.4851, -0.6104,  ...,  0.9023, -0.3816, -0.5176],
        [-0.4929, -0.7109, -0.5737,  ..., -0.2881,  0.5557, -0.4712],
        [-0.3848, -0.0364, -0.7100,  ...,  0.1835,  1.0000, -0.7925]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 93,  40,  97,  32,   2,  88,  25,  73,  25,  36,  23,  81,  82,  63,
        100,  75,  18,  68,  78,  53,   6,  66,  76,  77,  81,   3,  77,  78,
         53,  96,  38,  79,  41,   9,  34,   8,  81,  37,  63,  74,  98,  43,
          4,  86,  48,  64,  63,  37,  50,  91], device='cuda:0')
tensor([[-0.7134, -0.6826, -0.6128,  ...,  0.0089,  0.0952, -0.2322],
        [-0.4414, -0.8257, -0.4937,  ..., -0.1278, -0.0657, -0.7339],
        [-0.4038, -0.5264, -0.8691,  ..., -0.3008,  1.68

tensor([[-0.3987, -0.3506, -0.4521,  ...,  0.0152, -0.5093, -0.4575],
        [-0.2292, -0.4980, -0.7139,  ...,  0.0688,  0.5303, -0.2654],
        [-0.4927, -0.7397, -0.4856,  ..., -0.0115,  0.4187, -0.3896],
        ...,
        [-0.4932, -0.4692, -0.6929,  ...,  0.4937,  0.5732, -0.0748],
        [-0.4233, -0.6133, -0.4299,  ...,  0.3301,  0.1741, -0.0354],
        [-0.4680, -1.0488, -0.7051,  ..., -0.0036, -0.0947,  0.5464]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([39, 83, 49, 90, 83, 18, 37, 41, 64,  6, 77, 73, 91, 57, 26, 73, 99, 43,
        16, 98, 64, 18, 11, 82, 56, 22, 89, 74, 95, 91, 17, 94, 77, 43, 84, 74,
        70,  8, 90, 39, 59, 87, 95,  3, 53, 78, 86, 14, 98, 93],
       device='cuda:0')
tensor([[-0.7876, -0.1130,  0.2949,  ..., -0.2177,  0.0126,  0.0848],
        [-0.4387, -0.4858, -0.5586,  ...,  1.0449,  0.3193, -0.5132],
        [-0.7988, -0.8062, -0.8311,  ..., -0.2397,  0.0399,  0.4929],
        ...,
        [-0.5889, -0.782

tensor([[-0.4839, -0.5244, -0.5571,  ...,  1.6367,  0.0804, -0.1099],
        [-0.5693, -0.9663, -0.5742,  ..., -0.2661,  0.7666, -0.4309],
        [-0.7969, -0.4321, -0.6768,  ...,  0.0622,  0.6060,  0.9600],
        ...,
        [-0.5493,  0.5132,  0.3865,  ...,  0.1917, -0.1311, -0.3074],
        [-0.3484, -0.5068, -0.4438,  ...,  0.3342, -0.1989, -0.7358],
        [-0.8833, -0.9785, -0.4268,  ...,  0.1909,  0.2480, -0.1252]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 99,  50,  95,   5,  28,  96,  28,  85,  40,  44,  12, 100,  24,  66,
         98,  54,  34,  89,  21,  43,  78,   4,  73,  97,  33,  51,  62,  21,
         94,  40,  24,   1,  85,  17,  96,  97,  14,   2,  82,  92,  96,  12,
         77,   0,  35,  67,  76,   6,  61,  56], device='cuda:0')
tensor([[-0.6782,  0.5210,  0.3962,  ..., -0.1166,  0.3694, -0.2386],
        [-0.4575, -0.6338, -0.3030,  ...,  0.0705, -0.7212, -0.0139],
        [-0.7432,  0.0473, -0.5283,  ..., -0.5825,  0.39

tensor([[-0.1444, -0.8218, -0.3159,  ...,  0.5400,  0.7832, -0.0903],
        [-0.2556, -0.7783, -0.0608,  ...,  0.1829,  0.0950, -0.2981],
        [-0.3062, -0.0748,  0.1475,  ..., -0.1945,  0.0724, -0.2903],
        ...,
        [-0.4956, -0.7778, -0.3293,  ...,  0.3936, -0.1316, -0.3682],
        [-0.6382, -0.1097, -0.3679,  ...,  0.2786,  0.7051,  0.8921],
        [-0.5098, -0.6792, -0.7163,  ...,  0.3730,  1.1670,  0.0997]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 84,  12,  65,   8,  18,  22,  92,  18,  22,  66,  25,  39,  44,  45,
          7,  65,  51,  51,   9, 100,  89,  74,  98,  11,  48,  78,  89,   8,
         75,  74,  75,  90,  96,  64,  72,   3,  82,  83,  74,  64,   6,  72,
         94,   3,  33,  77,  41,  37,  95,  75], device='cuda:0')
tensor([[-0.5381, -0.4246, -0.1979,  ...,  0.0695,  0.4890,  0.2720],
        [-0.5850, -0.5732, -0.3777,  ...,  0.0735, -0.6523,  0.7627],
        [-0.7139,  0.3247,  0.6426,  ..., -0.1328,  0.12

tensor([[-2.6538e-01, -6.8799e-01, -6.7725e-01,  ..., -3.1592e-01,
          1.1924e+00, -3.3374e-01],
        [-2.1521e-01, -9.4141e-01, -9.4727e-02,  ..., -1.3513e-01,
          1.0176e+00,  1.4832e-01],
        [-8.6328e-01,  1.4197e-01,  5.4980e-01,  ...,  2.5781e-01,
         -1.9385e-01, -3.6597e-01],
        ...,
        [-6.6016e-01, -1.1250e+00, -2.8687e-01,  ...,  1.0414e-03,
         -5.8098e-03,  4.3091e-01],
        [-4.3555e-01, -7.3584e-01, -1.1493e-01,  ..., -1.4832e-01,
         -2.4658e-01, -8.5938e-01],
        [-3.9038e-01, -9.8022e-02, -3.9136e-01,  ...,  1.0042e-03,
          1.4880e-01, -5.5420e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 28,  72,  52,  40,  47,  23,  47,  33,  38,  45,  12,  86,  80,  97,
         96,  48,  70,  78,  38,   6,  96,  30,  90,  74,  96,   5,  58,  75,
        101,  40,  38,  77,  44,  45,  64,  30,  46,  33,  49,  36,  54,  90,
         64,  51,  43,  31, 101,   4,  64,  66], device='cuda:0'

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([81, 41, 92, 80,  6, 83,  1, 25, 25,  4, 83, 26, 96, 49, 74, 74, 85, 67,
        74, 70, 83, 52, 67, 89, 59, 37, 34, 29, 69, 62, 12, 15, 43, 67, 60, 64,
        39,  5, 77, 28, 54, 36, 73, 14, 67, 78,  8, 86, 18, 77],
       device='cuda:0')
tensor([[-0.4446,  0.3748, -0.4475,  ...,  0.7866, -0.1407, -0.2983],
        [-0.9736, -0.0186,  0.6616,  ..., -0.4937, -0.3560,  0.4617],
        [-0.5205, -0.7090, -0.6396,  ...,  0.0745,  0.6021, -0.4626],
        ...,
        [-0.8374, -1.2578, -0.7383,  ..., -0.5879,  0.5869,  1.0205],
        [-0.8765, -0.9180, -0.7783,  ...,  0.3250,  0.2389,  0.1431],
        [-1.1875,  0.1852, -0.1729,  ..., -0.2710,  0.4778,  0.0473]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([18, 71, 39, 88, 47, 39, 43, 77, 38, 78, 74, 35,  4, 65, 50, 68, 65, 82,
        93,  5, 62, 64, 47,  4, 65, 43, 47, 86, 62, 93, 66, 41, 49, 49, 78, 70,
        77, 73,  

tensor([[-9.1602e-01, -7.4219e-01, -6.3770e-01,  ..., -4.1626e-01,
          5.0449e-04,  7.4561e-01],
        [-3.0078e-01,  1.1810e-01, -4.4385e-01,  ...,  3.7036e-01,
          3.2983e-01, -8.1543e-01],
        [-5.7471e-01, -9.9805e-01, -4.0332e-01,  ...,  4.1455e-01,
          1.2676e+00, -1.2952e-01],
        ...,
        [-4.5972e-01, -9.5459e-01, -1.5857e-01,  ...,  3.8208e-01,
          6.6650e-01, -2.3999e-01],
        [-3.8989e-01, -7.3535e-01, -2.0557e-01,  ...,  1.0273e+00,
          5.3809e-01, -2.7783e-01],
        [-1.1064e+00, -9.3567e-02,  3.6652e-02,  ...,  1.4221e-02,
          2.9639e-01,  1.6025e+00]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 73,  96,  84,  73,  49,  59,  59,  40,  71,  81,  91,  80,  40,   6,
         74,  25,  46,  91,  56,  98,  21,  45,  69,  39,  44,   6,  32,  81,
         38,  96, 100,  89,  18,  78,  99,  77,   8,   3,  74,  87,  59,  23,
         78,  74,  87,   9,   3,  49,  99,  95], device='cuda:0'

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 6, 34, 18, 36, 88, 75, 67, 38, 72, 95, 33, 38, 54, 24, 43, 80, 92, 64,
        23, 20, 85, 97, 12, 33, 33, 29, 47, 49, 63, 24, 44,  7, 56, 77, 96,  7,
        18, 71, 90, 80, 80, 50, 49, 78, 80, 48, 50, 96, 59, 55],
       device='cuda:0')
tensor([[-0.4712, -0.6997, -0.3582,  ...,  0.2803,  0.8252, -0.2959],
        [-1.1592, -0.0087,  0.2272,  ..., -0.2258,  0.3984, -0.1863],
        [-0.8506, -0.1007,  0.7305,  ..., -0.4214, -0.3760,  0.0580],
        ...,
        [-0.7310, -1.0527, -0.4497,  ..., -0.1560,  0.1030,  0.5161],
        [-0.6685, -0.6235, -0.0474,  ..., -0.0622, -0.5386,  0.6309],
        [-0.5786, -0.5098, -0.3440,  ...,  0.2883, -0.0599, -0.5557]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([36, 22, 71, 68,  6, 64, 33, 49, 61, 78, 44, 59, 76, 92, 73, 74, 67, 72,
        54, 90, 40, 59, 84, 97, 48, 13, 69,  7, 43, 75, 92, 56, 35, 37, 36, 83,
        97, 90, 2

tensor([[-0.7617, -0.9429, -0.6021,  ...,  0.7178,  0.1163,  1.1777],
        [-0.3762, -0.6782, -0.2842,  ..., -0.0166, -0.2252, -0.8516],
        [-0.3733, -0.1813, -0.2581,  ...,  0.6333, -0.1343, -0.5698],
        ...,
        [-0.5771,  0.0396, -0.4358,  ...,  0.1743, -0.6543, -0.7373],
        [-0.8652, -0.1963, -0.3186,  ...,  0.1445, -0.2004,  1.0293],
        [-0.6001,  0.2568, -0.0176,  ...,  0.1471,  0.2001,  0.4741]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 81,  64,  10,  58,  11,  14,  26,  19,  37,  47,  49,  61,  24,  36,
         40,  52,  40,  71,  38,  77,  77,  35,  43,  20,  93,  82,  73,  98,
         22,  49,  90,  28,  64,  24,  22,   4,  94,  39,  71,  51,  72,  93,
         55,  50, 101,  59,  39,  39,  79,  77], device='cuda:0')
tensor([[-0.0930, -0.3503, -0.1643,  ...,  0.4319, -0.3940, -1.0557],
        [-0.3672,  0.1993,  0.2450,  ..., -0.3064, -0.2744, -0.3677],
        [-0.4895, -0.7705, -0.5835,  ...,  0.5132,  0.14

tensor([[-8.2764e-01, -6.6504e-01, -6.6602e-01,  ..., -3.3887e-01,
         -4.1260e-01,  8.9404e-01],
        [-7.9297e-01, -7.7490e-01, -6.7139e-01,  ..., -1.8542e-01,
          6.0669e-02,  5.9229e-01],
        [-3.7085e-01, -4.6655e-01, -3.9673e-01,  ...,  3.0078e-01,
          8.2568e-01, -1.1357e+00],
        ...,
        [-5.3027e-01, -7.5049e-01, -3.8379e-01,  ...,  2.6953e-01,
          1.1182e+00,  3.6499e-01],
        [-7.1631e-01, -1.2617e+00, -1.0498e+00,  ..., -3.0589e-04,
          2.2021e-01, -3.3855e-03],
        [-3.0444e-01,  8.7036e-02,  4.7266e-01,  ..., -3.6328e-01,
          3.2642e-01, -1.9861e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 88,  69,  70,  28,  89,   0,  29,  83,  96,  64,  27,  53,  29, 100,
         94,  12,  49,  88,  56,  49,   6,  26,  45,  20,  47,  89,  64,  97,
         90,  34,  73,  41,  49,  40, 101,  56,  35,  17,  58,  14,  96,  59,
         32,   6,  98,  54,  49,  32,  15,  65], device='cuda:0'

tensor([[-0.3577, -0.3206, -0.4265,  ...,  0.0750, -0.5410, -0.4734],
        [-0.1306, -0.4714, -0.6909,  ...,  0.1180,  0.5820, -0.2583],
        [-0.4163, -0.7427, -0.4800,  ...,  0.0453,  0.4719, -0.4075],
        ...,
        [-0.4368, -0.4553, -0.7139,  ...,  0.5645,  0.6113, -0.0428],
        [-0.3757, -0.6074, -0.3911,  ...,  0.4163,  0.1907, -0.0202],
        [-0.4150, -1.0928, -0.7104,  ...,  0.0327, -0.0916,  0.6626]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([39, 83, 49, 90, 83, 18, 37, 41, 64,  6, 77, 73, 91, 57, 26, 73, 99, 43,
        16, 98, 64, 18, 11, 82, 56, 22, 89, 74, 95, 91, 17, 94, 77, 43, 84, 74,
        70,  8, 90, 39, 59, 87, 95,  3, 53, 78, 86, 14, 98, 93],
       device='cuda:0')
tensor([[-0.7617, -0.0628,  0.4399,  ..., -0.1921, -0.0197,  0.1476],
        [-0.3618, -0.4629, -0.5361,  ...,  1.2500,  0.3442, -0.5479],
        [-0.7803, -0.7949, -0.8188,  ..., -0.2229,  0.0271,  0.5815],
        ...,
        [-0.5259, -0.800

tensor([[-0.6436, -0.5210, -0.5024,  ..., -0.2561, -0.3083,  0.2177],
        [-0.4448, -0.1149, -0.2852,  ...,  0.4299, -0.7251, -0.0400],
        [-0.3579, -0.2935, -0.7295,  ...,  0.1028, -0.0521,  0.2455],
        ...,
        [-0.1638, -0.2211, -0.5010,  ...,  0.2231,  0.6309, -0.9170],
        [-0.7109, -0.0862,  0.3860,  ...,  0.1544, -0.7065,  0.0068],
        [-0.4736, -0.8535,  0.1248,  ...,  0.0569,  0.5234,  0.1301]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([78, 18, 35, 99, 96, 28, 82, 53, 40, 77, 49, 43, 64,  7, 23, 64, 86, 84,
        59, 30, 22, 53, 23, 83, 83, 74,  6, 57, 59, 50,  8, 96, 75, 43, 89, 84,
        29, 43, 97,  7, 80, 77, 59, 96, 65, 15, 26, 36, 57, 40],
       device='cuda:0')
tensor([[-0.4824, -0.4419, -0.3845,  ...,  0.3079,  0.1610, -0.4465],
        [-0.0331, -0.5952, -0.4446,  ...,  0.0017,  1.1240, -0.7407],
        [-0.7700, -0.1207,  0.7954,  ..., -0.2844, -0.5571,  0.7378],
        ...,
        [-0.5444, -0.908

tensor([[-0.0537, -0.8408, -0.2517,  ...,  0.6216,  0.8931, -0.0659],
        [-0.1559, -0.7983,  0.0283,  ...,  0.2323,  0.1399, -0.2881],
        [-0.2449, -0.0233,  0.2649,  ..., -0.1840,  0.0953, -0.3052],
        ...,
        [-0.4202, -0.7822, -0.2637,  ...,  0.4692, -0.1223, -0.3630],
        [-0.6060, -0.0479, -0.3005,  ...,  0.3384,  0.7563,  1.0361],
        [-0.4458, -0.6689, -0.7134,  ...,  0.4521,  1.3037,  0.1246]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 84,  12,  65,   8,  18,  22,  92,  18,  22,  66,  25,  39,  44,  45,
          7,  65,  51,  51,   9, 100,  89,  74,  98,  11,  48,  78,  89,   8,
         75,  74,  75,  90,  96,  64,  72,   3,  82,  83,  74,  64,   6,  72,
         94,   3,  33,  77,  41,  37,  95,  75], device='cuda:0')
tensor([[-0.4773, -0.4033, -0.1521,  ...,  0.1136,  0.5669,  0.3191],
        [-0.5464, -0.5415, -0.3362,  ...,  0.1002, -0.7158,  0.8765],
        [-0.6675,  0.4429,  0.7705,  ..., -0.1125,  0.14

tensor([[-0.7578, -0.5391, -0.1592,  ..., -0.1075,  0.3054,  0.2499],
        [-0.7515,  0.7231,  0.3718,  ...,  0.1637, -0.4785,  0.0151],
        [-0.9419, -1.1641, -0.6548,  ...,  0.7383,  0.6196,  1.0615],
        ...,
        [-0.7168, -0.9741, -0.4668,  ..., -0.0401,  0.8623, -0.2717],
        [-0.1918, -0.4226, -0.8735,  ...,  0.2084,  1.4307, -1.2803],
        [-0.2228, -0.7300, -0.0648,  ...,  0.4102, -0.2139, -0.9673]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,   6,  81, 101,  18,  11,  49,  75,  84,  24,  40,  27,  23,  87,
         87,  27,  74,  74,  34,  85,  97,  43,  98,  38,  54,  90,  79,  33,
         40,  62,  33,  96,  41,  13,  73,  72,  99,  97,  77,  74,   3,  54,
         87,  45,  80,  73,  77,  43,  36,  45], device='cuda:0')
tensor([[-1.0039, -0.8340, -0.1054,  ..., -0.1252,  0.4956,  2.1680],
        [-0.6138, -1.2393, -0.4316,  ...,  0.6226,  0.7285,  0.0112],
        [-0.1423, -1.0576, -0.5806,  ...,  0.2134,  0.97

tensor([[-0.8125, -0.4160, -0.3979,  ...,  0.5679, -0.0107,  0.6250],
        [-0.6968, -0.7324, -0.4297,  ..., -0.0911,  0.3091,  0.5630],
        [-0.7939, -0.2375, -0.5854,  ..., -0.0820,  0.8584,  0.5977],
        ...,
        [-0.6929, -0.9775, -0.3442,  ..., -0.3796,  1.3828,  0.0881],
        [-0.4060, -0.0309, -0.8174,  ...,  0.1171, -0.5171, -0.3311],
        [-0.3596,  0.3430,  0.0692,  ..., -0.0341,  0.1149, -0.1791]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([81, 41, 92, 80,  6, 83,  1, 25, 25,  4, 83, 26, 96, 49, 74, 74, 85, 67,
        74, 70, 83, 52, 67, 89, 59, 37, 34, 29, 69, 62, 12, 15, 43, 67, 60, 64,
        39,  5, 77, 28, 54, 36, 73, 14, 67, 78,  8, 86, 18, 77],
       device='cuda:0')
tensor([[-0.4009,  0.4739, -0.4185,  ...,  0.9453, -0.1592, -0.3208],
        [-0.9561,  0.0339,  0.8315,  ..., -0.5137, -0.4004,  0.5444],
        [-0.4670, -0.7109, -0.6440,  ...,  0.1489,  0.6650, -0.4678],
        ...,
        [-0.8306, -1.298

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([49,  5,  3, 77, 92, 33, 69, 49, 19, 49, 27, 71, 89, 16, 31, 96, 12, 65,
        66, 40, 62, 80, 86, 23, 96, 45, 41, 40, 90, 77, 88, 31, 95, 50, 80, 97,
        46, 67, 65, 19, 64, 86, 39,  0, 65, 89, 74, 45, 84, 95],
       device='cuda:0')
tensor([[-0.5547,  1.4824,  0.1802,  ..., -0.1093,  0.4998, -0.1455],
        [-0.8438,  0.3850,  0.7837,  ..., -0.4761, -0.5439,  0.2001],
        [-1.1260,  0.0140,  0.3662,  ..., -0.2388,  0.3391,  0.2437],
        ...,
        [-0.4392, -0.6201, -0.0029,  ..., -0.1283,  0.6748, -0.5244],
        [-1.0215, -0.5938, -0.4424,  ..., -0.0758,  0.5508,  1.8066],
        [-0.1584, -0.6348, -1.0859,  ...,  0.0963,  1.0986, -1.0684]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 31,  52,  22,  24,  44,  71,  77,  75,   5,  78,  64,  94,  49, 101,
        101,  43,  98,  62,  87,  32,  37,  72,  84,  90,  42,  28,  73,  92,
         98,  96,  73

tensor([[-0.3167,  1.0645,  0.6797,  ...,  0.0829, -0.5928, -0.3435],
        [-0.3765,  1.0391, -0.2389,  ...,  0.5596, -0.5039, -0.9570],
        [-0.3091,  0.1349, -0.2198,  ...,  0.8130, -0.0733, -0.4558],
        ...,
        [-0.0776, -0.3442, -0.3884,  ...,  0.3298,  0.5767, -0.4910],
        [-0.6621, -0.4294,  0.1945,  ...,  0.1897, -0.5200, -0.0075],
        [-0.1823, -0.5303, -0.5078,  ...,  0.2874,  0.3816,  0.1340]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 6, 34, 18, 36, 88, 75, 67, 38, 72, 95, 33, 38, 54, 24, 43, 80, 92, 64,
        23, 20, 85, 97, 12, 33, 33, 29, 47, 49, 63, 24, 44,  7, 56, 77, 96,  7,
        18, 71, 90, 80, 80, 50, 49, 78, 80, 48, 50, 96, 59, 55],
       device='cuda:0')
tensor([[-0.4131, -0.7095, -0.3384,  ...,  0.3342,  0.9028, -0.2776],
        [-1.1475,  0.0417,  0.3318,  ..., -0.2030,  0.4158, -0.1764],
        [-0.8354, -0.0370,  0.9058,  ..., -0.4265, -0.4241,  0.0900],
        ...,
        [-0.7100, -1.081

tensor([[-0.8804, -1.0117, -0.4331,  ..., -0.7622,  0.4949,  0.5039],
        [-0.4744, -0.4521, -0.3892,  ..., -0.1294, -0.0203, -0.3257],
        [-0.3982, -0.6108, -0.6694,  ...,  0.5439, -0.1365,  0.7422],
        ...,
        [-0.8076,  1.3965,  0.1659,  ..., -0.1390,  0.0667, -0.3110],
        [-0.9268, -0.6689, -0.6538,  ...,  0.7300,  0.6714,  0.4319],
        [-0.8208, -0.0661,  0.2571,  ...,  0.1813,  0.7002,  0.3538]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 82,  66,  37,  77,  78,  54,  20,  73,  90,  49,  91,  18,  33,  54,
         43,  39,  61,  32,  38,  80,  60,  96,  72,  24,  19,  82, 101,  97,
          3,  16,  77,  14,  74,  71,  54,  83,   6,  60,  89,  83,  23,  83,
         57,  78,  81,   6,  44,  31,  40,  22], device='cuda:0')
tensor([[-0.7217,  0.0915, -0.6104,  ..., -0.1060,  1.2178, -0.7817],
        [-0.7808, -0.5537,  0.4070,  ...,  0.1143, -0.7959,  0.1158],
        [-0.3252, -0.1779,  0.2432,  ..., -0.2438,  0.17

tensor([[-0.7295, -0.8008, -0.6406,  ..., -0.2325, -0.0826,  0.7212],
        [-0.6714,  0.6660,  0.4456,  ..., -0.2112, -0.0707, -0.1155],
        [-0.3040, -1.1123, -0.3350,  ...,  0.2898,  0.3574,  0.0133],
        ...,
        [-1.4023,  0.9604, -0.3982,  ..., -0.0468,  0.1687,  0.2433],
        [-0.2371, -0.5693, -0.6851,  ...,  0.6211,  0.3889,  0.1809],
        [-0.9810, -0.9336, -0.6055,  ...,  0.4214,  0.5205,  1.1631]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([73, 48, 12, 89, 44, 64, 81, 75,  7, 97, 33, 89, 62, 74, 78, 49, 21, 50,
        82, 44, 65, 22, 56, 90, 74, 56, 38, 42, 18, 19, 11, 94, 78, 61, 75, 26,
        43, 65, 44, 97, 45, 77, 72, 74, 72, 43,  1, 20, 75,  3],
       device='cuda:0')
tensor([[-0.8145, -0.6460, -0.6270,  ..., -0.3369, -0.4583,  0.9971],
        [-0.7739, -0.8062, -0.6753,  ..., -0.1754,  0.0587,  0.6821],
        [-0.2805, -0.4412, -0.3682,  ...,  0.3630,  0.9116, -1.2207],
        ...,
        [-0.4985, -0.779

       grad_fn=<AddmmBackward>) tensor([81, 31], device='cuda:0')
Training: Epoch 17 || Loss:   2.150 || Accuracy:  67.09%
tensor([[-0.7021,  0.9731,  0.5044,  ...,  0.0496,  0.0068,  0.4978],
        [-0.5698, -0.9966, -0.5054,  ...,  0.5312,  0.7847,  0.2397],
        [ 0.0392, -0.4644, -0.1515,  ...,  0.5264,  0.0333, -0.9585],
        ...,
        [-0.0189,  0.6411, -0.3113,  ..., -0.3999,  0.6050, -0.6401],
        [ 0.1225, -0.9150, -0.2590,  ...,  0.3152,  0.8799, -0.4111],
        [-0.5439, -0.7358, -0.2233,  ...,  0.4119,  0.5762,  0.0313]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.5161,  0.3147,  0.0667,  ..., -0.0348,  0.6304,  0.7114],
   

tensor([[-0.6831, -1.0361, -0.2520,  ...,  0.4509,  0.9458,  0.2166],
        [-0.3657, -0.9966, -0.4658,  ...,  0.6602,  0.9624, -0.6250],
        [-1.0244,  0.6196, -0.2686,  ...,  0.1296,  0.3237,  1.1543],
        ...,
        [-0.7363, -0.6304, -0.1199,  ..., -0.1432, -0.8271,  0.8623],
        [-0.3875, -0.2045, -0.9653,  ...,  1.5527, -0.0901, -0.9780],
        [-0.8716, -0.1023,  1.4297,  ..., -0.4343, -0.6426,  0.6602]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 92,  90,  77,   4,  98,  38,  74,  76,  40,  90,  96,  77,  77,  91,
         46,  75, 101,  12,  38,  48, 101,   5,  83,  26,  21,  75,  90,  96,
         35,  99,  72,   7,  96,  38,  87,  85,  61,  35,  85,  84,  56,  43,
         36,  43,  50,   2,  49,  73,  99,  71], device='cuda:0')
tensor([[-0.3503, -1.2910, -0.2927,  ...,  0.4919,  0.3535,  0.2776],
        [-0.3647, -0.6714, -0.6934,  ..., -0.1232,  0.1989,  0.4756],
        [-0.7134, -1.2471, -0.0220,  ...,  0.3997,  0.33

tensor([[-0.2537, -0.9858, -0.7300,  ...,  0.0721,  1.2334, -0.0171],
        [-0.7402, -1.0713, -0.6021,  ...,  0.7515,  0.1274,  0.7739],
        [-0.8408, -0.8228, -0.4146,  ..., -0.4194, -0.3977,  0.8091],
        ...,
        [-0.5088, -0.7529, -0.4099,  ...,  0.5112,  0.8276,  0.4346],
        [-0.7319, -0.4683, -0.9355,  ..., -0.0321,  0.4421, -0.0814],
        [-0.4607, -0.3340, -0.6416,  ...,  0.5195, -0.5991,  0.4519]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 72,  81,  73,  75,  76,  74,  59,  75,  49,  49,  91,  11,  90, 100,
         77,  14,  55,  99,  70,  49,  56,  64,  86,  58,  83,  48,  17,  59,
         58,  21,  59,  60,   5,  93,  91,  56,  43,  49,  94,  18,  32,  48,
         97,  33,  12,  13,  73,  41,  55,  18], device='cuda:0')
tensor([[-0.7451,  0.6431, -0.5249,  ...,  0.3560,  0.0759, -0.2817],
        [-0.7192, -0.6274,  0.2922,  ..., -0.1105,  0.0906, -0.4480],
        [-0.3149, -0.2737, -0.4121,  ...,  0.4380,  1.36

tensor([[-0.1968,  1.0400,  0.4805,  ...,  0.1499,  0.3652, -0.4836],
        [-0.2656, -0.9023, -0.3284,  ...,  0.7744, -0.2075, -0.6045],
        [-0.7705, -0.0948,  0.7139,  ...,  0.2781, -0.4460,  0.3059],
        ...,
        [-0.1587, -0.6250, -0.5347,  ...,  0.4702, -0.1553, -1.0068],
        [-0.4746,  0.6777,  0.8628,  ..., -0.2747,  0.0333,  0.3870],
        [-0.5752, -0.3350, -0.1948,  ...,  0.2476,  0.6646, -0.4277]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 8, 42, 38, 64, 48,  0, 64, 61, 98, 25, 82, 49, 56, 12,  8, 39, 15, 85,
         2, 33, 97, 42, 71, 70, 86, 81, 94, 46, 53, 33, 96, 40, 78, 63, 30, 47,
        90, 93, 82, 90, 77, 88, 74, 14, 96, 13, 38, 64, 11, 67],
       device='cuda:0')
tensor([[-0.2698,  0.0640, -0.8101,  ...,  0.4514,  0.9849, -0.1042],
        [-0.4023,  1.1650,  0.4675,  ...,  0.0189, -0.4609, -0.4136],
        [-0.5410, -1.3447, -0.4812,  ...,  0.8535,  1.3730,  0.6548],
        ...,
        [-0.0019, -0.121

tensor([[-0.3242,  0.0236, -0.7983,  ...,  1.4795, -0.1016, -0.7144],
        [-0.0031, -0.8584, -0.7710,  ...,  0.3264,  1.9609, -0.6699],
        [-0.6787, -0.9272, -0.6338,  ...,  0.1426,  0.1597,  0.8384],
        ...,
        [-0.3149, -0.6602, -0.5532,  ...,  0.4395,  1.2705, -0.2288],
        [-1.0479, -0.5396, -0.5229,  ...,  0.0999,  0.4717,  1.6250],
        [-0.2808, -0.8433, -0.3137,  ...,  0.3782,  0.3987,  0.3196]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 99,  53,  69,  81,  61,  18,  78,  33,  39,  43,  65,  50,  90,  86,
         55,  15,  83,  32,  14,  77,  33,  82,  62,  24,  25,  84,  51,   2,
         56,  73,  89, 101,  14,  97,   9,  26,  49,  48,  85,  97,  49,  89,
         38,  77,  43,  64,  24,  72,  19,  72], device='cuda:0')
tensor([[-0.7788, -0.2588, -0.4678,  ...,  0.3806,  0.4597, -0.9521],
        [-0.1490, -0.6489,  0.2474,  ...,  0.0327,  0.4753, -0.6929],
        [-0.4199, -1.1074, -0.1675,  ...,  0.5654,  1.29

tensor([[-0.4502, -0.4424,  0.3428,  ...,  0.2395, -0.9575,  0.1708],
        [-0.2952, -0.3250, -0.3301,  ...,  0.5078,  0.1775, -0.5249],
        [-0.6855, -0.5249,  0.0038,  ..., -0.1511, -0.7583,  0.9829],
        ...,
        [-0.7661, -0.8315, -0.4521,  ...,  0.0681,  0.9419,  0.2783],
        [-0.2158, -0.6201, -0.3521,  ...,  0.0779,  1.1504, -0.6069],
        [-0.6558,  1.0850,  0.2847,  ...,  0.4568,  0.0633,  0.5464]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([59, 70, 73, 38, 90, 82, 69, 11, 97, 76, 26, 61, 31, 89, 74, 96, 63, 97,
        63, 77, 10,  6, 45, 91, 94, 53, 57, 82, 77, 76, 89, 51, 73, 15, 42, 61,
        17, 87, 74, 89, 13, 59, 63, 43, 59, 90, 31, 43, 49, 77],
       device='cuda:0')
tensor([[-0.1868, -0.8452, -0.6099,  ...,  0.3174,  0.4333, -0.4971],
        [-0.6748, -1.1113, -0.2874,  ..., -0.0091,  0.7119, -0.1847],
        [-0.7925, -1.0625, -0.5312,  ...,  0.3477,  0.9448,  0.7373],
        ...,
        [-0.0529, -0.502

tensor([[-0.2195, -0.6538, -0.1958,  ...,  0.7627,  0.2593, -0.5776],
        [-0.1731, -0.4934, -0.5552,  ...,  0.9136,  0.1224, -0.6670],
        [-0.4209,  0.4292, -0.6445,  ..., -0.0848,  0.6724,  0.0884],
        ...,
        [-0.6250, -0.4746,  0.1282,  ...,  0.0919,  0.2976,  0.9062],
        [-0.3176, -0.3047, -0.5386,  ...,  2.1309,  0.0815, -0.4053],
        [ 0.2478, -0.8516, -0.6699,  ...,  0.4141,  1.8330, -1.0957]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 51,  96,  94,  49,  20,  89,  55,  98,  78,  49,  32,  98,  64,  74,
         87,   8,  17,  77,  49,   5,  75,  24,  38,  89,  89,   2,  11,  76,
         42,  88,  64,  28,  59,  98,  89,   3,  84,  98,  87,  43,  33,  48,
         70,  40,  33, 100,  78,   4,  99,  53], device='cuda:0')
tensor([[-0.5757, -0.9800, -0.3894,  ...,  0.2375,  0.6304,  0.4722],
        [-0.4749,  0.2139,  1.4170,  ...,  0.0322, -0.3108,  0.2302],
        [-0.3643, -1.3350, -0.5078,  ...,  0.2808,  1.39

tensor([[-0.5981, -0.4878, -0.6030,  ..., -0.2009, -0.6943,  0.9951],
        [-0.3691, -0.4197, -0.5083,  ..., -0.0122, -0.7622, -0.3374],
        [-0.0861, -0.8892, -0.5161,  ...,  0.0864,  1.8330, -0.2922],
        ...,
        [ 0.0144, -0.0327, -0.9761,  ...,  1.1592,  1.3252, -0.6807],
        [-0.4551, -0.5898, -0.5015,  ...,  1.1875,  0.4741, -0.2783],
        [-0.6133, -0.6689, -1.0186,  ..., -0.1428,  0.6826, -0.0991]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 93,  68,  28,   6,  24,  76,  65,  27,  24,   3,  85,  21,  49,  62,
         53,   3,  62, 100,  48,  78,  77,  36,  79,  51,  94,  27,  54,  51,
         53, 100,  41,  58,  98,  93,  58,  94,   5,  49,  76,  64,   3,  74,
         12,  52,   0,  69,  10,  13,  90,  55], device='cuda:0')
tensor([[-0.0210, -1.0225, -0.7124,  ...,  0.3567,  2.5469, -0.2888],
        [-0.6880, -0.3726, -1.1201,  ..., -0.1844, -0.0475,  0.1255],
        [-0.6206, -0.9624, -0.7808,  ...,  0.1895, -0.12

tensor([[-0.6045, -0.7588, -0.7236,  ...,  0.1930,  1.2832,  0.6460],
        [ 0.0044, -0.4807, -0.0576,  ..., -0.0827,  1.8193, -0.3625],
        [ 0.0934, -1.1436, -0.3157,  ...,  0.0800,  1.5947,  0.0517],
        ...,
        [-0.0404, -0.2205, -0.2593,  ...,  1.0791, -0.3594, -0.2512],
        [-0.6509, -0.8320, -0.2102,  ...,  0.1305,  0.3740, -0.4868],
        [-0.1447, -0.2764, -0.5781,  ...,  0.3586, -0.1791, -0.4792]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 32,  49,  72,  92,  37,  92,  20,  63,  74,  86,  71,   5,  41,  62,
          5,  82,   7,  29,  53, 100,  90,  49,  20,  17,  87,  78,  88,  24,
         44,  84,  56,  83,  42,  26,  47,  57,  43,  62,  18,  56,  19,   6,
         51,  63,  98,  85,   6,  37,  43,  89], device='cuda:0')
tensor([[-0.4172, -0.1929, -0.0566,  ..., -0.1648, -0.0638, -0.4080],
        [-0.5386, -0.9966, -0.7485,  ...,  0.9307,  0.5630,  0.3816],
        [-0.3022, -0.8662, -1.0264,  ...,  0.2168, -0.08

tensor([[-0.6523, -0.7524, -0.1053,  ...,  0.4419,  0.2460, -0.0829],
        [-0.5752, -0.0912,  0.9307,  ..., -0.2307, -0.1483,  0.0844],
        [-0.1781, -1.1660, -0.3171,  ...,  1.1006,  0.9243, -0.2256],
        ...,
        [-0.0981, -0.2690, -0.4783,  ..., -0.2454,  1.4854, -0.8384],
        [-0.1243, -1.1064, -0.4919,  ..., -0.1906,  0.6621, -0.5879],
        [-0.2505, -0.4219, -0.0459,  ...,  0.5083,  0.5967, -0.1855]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([92, 47, 90, 12, 97, 16, 93, 49, 72, 32,  3, 94, 39, 82, 41, 76, 40, 72,
        59, 64, 49, 12, 90,  7,  8, 51, 52, 76, 75, 25, 73, 83, 44, 38, 82, 75,
        73, 10, 91, 77, 25, 95, 81, 21, 78, 47, 83, 49, 50, 49],
       device='cuda:0')
tensor([[-7.5049e-01, -3.0005e-01, -7.5879e-01,  9.6777e-01,  9.5154e-02,
          1.6272e-01, -1.0479e+00,  2.5269e-02, -3.3789e-01, -4.0918e-01,
         -2.1594e-01,  5.4199e-01, -2.5269e-01, -4.3457e-01,  3.4473e-01,
          1.0625e+00,  2.

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 93,  40,  97,  32,   2,  88,  25,  73,  25,  36,  23,  81,  82,  63,
        100,  75,  18,  68,  78,  53,   6,  66,  76,  77,  81,   3,  77,  78,
         53,  96,  38,  79,  41,   9,  34,   8,  81,  37,  63,  74,  98,  43,
          4,  86,  48,  64,  63,  37,  50,  91], device='cuda:0')
tensor([[-0.6152, -0.6475, -0.5532,  ...,  0.2896,  0.1261, -0.2354],
        [-0.1233, -0.8662, -0.4165,  ..., -0.0294, -0.0366, -0.9497],
        [-0.1885, -0.4202, -0.9111,  ..., -0.1564,  2.3574, -0.1857],
        ...,
        [-0.3269,  0.3201, -0.1898,  ...,  1.0488,  0.4995, -0.5312],
        [-0.2737, -0.2910, -0.4453,  ...,  2.4199, -0.5361, -0.2864],
        [-0.8667, -0.6128, -0.2274,  ...,  0.0778, -0.1940,  0.0570]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,  50, 100,  49,  45, 100,   7,  40,  99,  18,  44,  97,  24,  43,
         26,  84,  89,  54,  78,  67,  26,  84, 

tensor([[-0.3557, -0.9795, -0.9502,  ...,  0.1835,  1.4141,  0.1193],
        [-0.5640, -0.5718, -0.7485,  ...,  0.5049,  0.6890, -0.4973],
        [-0.2791, -0.6304,  0.1874,  ...,  0.2776, -0.5889,  0.2937],
        ...,
        [-0.6484, -0.9165, -0.4326,  ...,  0.4824, -0.1125, -0.2269],
        [-0.4265, -0.6011, -0.3918,  ..., -0.0812,  0.1855, -0.9082],
        [-0.3801,  0.7554,  1.0537,  ..., -0.0447, -0.1289, -0.4597]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 97,  70,  54,  74,  75,  38,  77,  84, 101,  34,  97,  99,  81,  87,
         12,  32,  90,  69,  53,  89,  73,  62,  71,  43,   8,  43,  64,  74,
        100,  99,  61,  76,  52,  70,  36,  97,  52,  57,  33,  80,  90,  81,
         65,   6,  84,  75,  61,  56,  45,   6], device='cuda:0')
tensor([[-0.1743, -0.9946, -0.7354,  ...,  0.1215,  1.3447,  0.0125],
        [-0.7227, -1.0693, -0.5723,  ...,  0.8511,  0.1127,  0.8540],
        [-0.8301, -0.8086, -0.3760,  ..., -0.4165, -0.44

tensor([[-0.4473,  0.9673,  0.7305,  ...,  0.0475,  0.4539, -0.2496],
        [-0.3350, -0.6489, -0.0751,  ...,  0.3225, -0.9346,  0.0922],
        [-0.5693,  0.2065, -0.5332,  ..., -0.7041,  0.4541,  0.5288],
        ...,
        [-0.4031, -0.5957, -0.3215,  ...,  0.1040, -0.1345, -0.2305],
        [-0.6709, -0.6162, -0.4343,  ...,  0.0600, -0.0252, -0.0318],
        [-0.4368, -0.7974,  0.0085,  ..., -0.0989,  0.6006,  0.5532]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([48, 54, 33, 79, 13, 93, 17, 43,  8, 60, 51, 74, 98, 83, 86, 93, 18, 60,
        67, 96, 46, 34, 96, 77, 78, 46,  8, 75, 72, 16, 90, 24, 24, 87, 70, 59,
        56, 50, 73, 51, 83, 82, 39, 80,  2, 59, 74, 74, 74, 26],
       device='cuda:0')
tensor([[-0.6694, -0.4294,  0.2932,  ...,  0.1188, -0.9985,  0.4438],
        [-0.4194, -0.8306, -0.4636,  ..., -0.0968,  2.2910,  0.5356],
        [-0.2764, -0.1899, -0.5820,  ...,  0.5479,  0.8281, -0.8936],
        ...,
        [-0.7085, -0.923

tensor([[-0.0664,  0.0547, -0.6616,  ..., -0.0666,  1.6934, -0.4988],
        [-0.3721, -0.9351, -0.5674,  ...,  0.5112, -0.1537,  0.6929],
        [-0.7827,  2.8926,  0.3779,  ..., -0.1025, -0.3235,  0.1644],
        ...,
        [-0.5264, -0.6743, -0.2483,  ...,  0.2478,  1.2432,  1.1592],
        [-0.7119, -0.4646, -0.3003,  ..., -0.0974, -0.9263,  0.1742],
        [-0.5723,  0.3083, -0.0803,  ...,  0.6094,  0.2625,  1.2061]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([67, 93,  1, 16, 78, 52, 14, 48,  6, 84, 49, 49, 43, 77, 73, 64, 12, 51,
        34, 22, 39, 33, 43, 73,  5, 52, 14,  6, 53, 64, 43, 76, 73, 77, 32, 72,
        38, 51, 62, 45, 89, 41, 90, 18, 89, 54, 61, 32, 73, 77],
       device='cuda:0')
tensor([[-0.9897,  1.0928, -0.7866,  ...,  0.0322,  0.3115,  0.1088],
        [-0.9106,  0.2605, -0.2212,  ...,  0.8101,  0.3735,  0.6797],
        [-1.0361, -1.0010,  0.0138,  ...,  0.8423,  0.1462,  1.4463],
        ...,
        [-0.6177, -0.544

0it [00:00, ?it/s]

Found best model
Training: Epoch 20 || Loss:   1.953 || Accuracy:  72.63%
tensor([[-0.6685,  1.1641,  0.7070,  ...,  0.1317, -0.0322,  0.5898],
        [-0.5020, -1.0488, -0.4719,  ...,  0.6841,  0.8892,  0.3145],
        [ 0.1980, -0.4236, -0.0759,  ...,  0.6685,  0.0682, -1.0684],
        ...,
        [ 0.1332,  0.8003, -0.2129,  ..., -0.4165,  0.7129, -0.6929],
        [ 0.2896, -0.9429, -0.1874,  ...,  0.4075,  1.0146, -0.4160],
        [-0.4707, -0.7656, -0.1724,  ...,  0.5239,  0.6343,  0.0760]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.4421,  0.4277,  0.2007,  ...,  0.0075,  0.7114,  0.8174],
        [ 0.1677, -0.4668,  0.1670,  ...,  0.7427,  

tensor([[-2.4524e-01,  1.1152e+00,  8.0566e-01,  ..., -1.5454e-01,
          1.0080e-03, -3.7646e-01],
        [ 1.3000e-01, -6.2354e-01, -8.3643e-01,  ...,  5.0537e-01,
          9.7900e-01, -4.5874e-01],
        [ 2.8198e-01, -3.0688e-01, -3.9160e-01,  ...,  8.2568e-01,
         -2.5171e-01, -1.1104e+00],
        ...,
        [-4.7485e-01, -4.7314e-01, -5.0098e-01,  ...,  1.7175e-01,
         -6.5063e-02,  6.4209e-01],
        [-1.6565e-01,  6.9727e-01, -4.2383e-01,  ...,  1.1902e-01,
         -1.1772e-02, -1.3367e-01],
        [-6.6846e-01,  3.0176e+00,  2.3096e-01,  ...,  1.2036e-01,
         -1.0040e-01,  1.6760e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 48,  49,  61,  50,  35,  15,  46,   5,  44,  51,  90,  71,  27,  39,
         37,  18,  60,  23,  37,  71,  83,  81,  34,  35,  82,  25,  99,  97,
         85,  89,  51,  82,   7,  48,  74,  19,  64, 101,  24,  86,  14,  34,
         93,  72,  72,  86,  88,  78,  66,   1], device='cuda:0'

tensor([[-0.3657, -1.1680,  0.0854,  ...,  0.3672,  0.3740,  0.6646],
        [-0.5811,  0.0815,  0.5088,  ...,  0.3691, -0.7617,  0.3267],
        [-0.1578, -0.4905, -0.2085,  ...,  0.6411,  1.6357, -0.3621],
        ...,
        [-0.3450,  1.4609, -0.1019,  ...,  0.6777,  0.3694, -0.3896],
        [ 0.0189, -0.7466, -0.0784,  ...,  0.3840, -0.6230,  0.3081],
        [ 0.2998, -0.9561, -0.4536,  ...,  0.4817,  0.8389, -0.4465]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([16, 57, 67, 74, 80, 31, 38, 37, 83, 82, 57, 73, 83, 75, 75, 62, 49, 41,
        81, 56, 33, 24, 91, 89, 77, 74, 93, 40, 45, 84, 54, 77, 18, 18, 53, 88,
        38, 55,  5, 80, 88, 71,  7, 97, 79, 96, 74, 77, 12, 49],
       device='cuda:0')
tensor([[-0.1583,  0.3203, -0.7949,  ...,  0.1656,  0.2639, -0.7012],
        [-0.5854, -0.6030, -0.6816,  ..., -0.1672,  0.4805, -0.0565],
        [ 0.0095, -0.1744, -0.1439,  ...,  0.4097,  0.9409,  0.3145],
        ...,
        [-0.4929, -0.885

tensor([[-0.6973,  0.8330, -0.4685,  ...,  0.4329,  0.0743, -0.2729],
        [-0.6597, -0.6299,  0.4434,  ..., -0.0591,  0.0889, -0.4485],
        [-0.1874, -0.2104, -0.3513,  ...,  0.5498,  1.5498, -0.3210],
        ...,
        [-0.1743,  0.2959, -0.2637,  ...,  1.0596, -0.1387, -0.1674],
        [-0.8047, -1.0303, -0.9346,  ...,  0.4075,  0.0685,  0.1857],
        [-0.6372, -0.7368, -0.3621,  ...,  0.2078,  0.0479,  0.6011]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  43,  49,  17,  22,  21, 101,  78,  85,  74,  38,  98,  54,  77,
         59,  15,  41,  31,  78,  20,  73,  38,  11,  92,  49,  40,  73,  19,
         49,  79,  85, 101,  47,  88,  37,  90,  53,  41,  66,  90,  89,  92,
         49,  99,  90,  88,  25,  96,  15,  78], device='cuda:0')
tensor([[-0.4324,  1.5098,  0.4102,  ...,  0.0667,  0.2734,  0.0281],
        [-0.2435, -0.9907, -0.4768,  ...,  0.8633,  0.3606, -0.0185],
        [-0.3809,  0.6958,  0.2512,  ...,  0.4604,  0.18

tensor([[-6.9775e-01,  2.2546e-01,  9.3506e-01,  ..., -3.3472e-01,
          1.4001e-01, -2.0972e-01],
        [ 1.2524e-01, -4.9683e-01, -6.7969e-01,  ...,  1.5820e+00,
         -6.1523e-01, -4.3530e-01],
        [-5.3857e-01, -1.9067e-01, -9.8096e-01,  ..., -2.6587e-01,
          5.1025e-01,  4.2236e-02],
        ...,
        [-4.1870e-01,  4.3433e-01,  1.1006e+00,  ..., -2.3901e-01,
         -5.7324e-01,  1.4229e-03],
        [-3.0981e-01,  1.6272e-01, -6.0742e-01,  ...,  4.5044e-02,
          1.3115e+00, -1.3584e+00],
        [-2.3718e-01,  2.3120e-01,  6.9482e-01,  ...,  2.0142e-01,
         -6.8555e-01, -1.6833e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 44,  99,  33,  52,  56,  82,  72,  65,  93,  78,   5,  27,  11,  80,
         33,  50,   6,  99,  34,  56,  54,  10,  14,  30,  68,  68,  85,  82,
          2,  47, 100,  59,  28,   7,  41,  90,   3,   9,  54,  53,  96,  90,
         99,  41,  35,   0, 100,   2,  91,  38], device='cuda:0'

tensor([[ 0.1338, -0.1642, -0.5347,  ...,  0.0161,  1.3740, -0.2418],
        [-1.0615, -0.7783, -0.1373,  ..., -0.2651,  0.8188,  2.2949],
        [ 0.0506, -0.4189, -0.3511,  ...,  0.6670, -0.3787, -1.1260],
        ...,
        [-0.8218, -0.8594, -0.2375,  ...,  0.2378,  0.4836,  0.0080],
        [-0.0919, -0.3860, -0.3506,  ..., -0.2458,  1.3633, -0.7188],
        [-0.7622, -0.4233, -0.2021,  ..., -0.3445, -0.8340,  0.7705]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 13,  80,  61, 100,  40,  33,  85,  49,  93,  75,  93,  62,  17,  57,
         78,  78,  77,  74,  24,  80,  33,  25,  29,  12,  78,   1,  85,  55,
         59,  90,  57,  51,  70,   0,  96,  15,  25,  18,  51,  47,  89,  82,
         38,  28,  90,  47,  53,  43,  28,  73], device='cuda:0')
tensor([[-0.6934, -0.5688, -0.0687,  ..., -0.2012, -0.1033, -0.0323],
        [-0.4604, -0.8062, -0.0739,  ...,  0.4158, -0.1533,  0.2991],
        [-0.5815,  1.0400,  0.8584,  ..., -0.2932,  0.14

tensor([[-0.2212,  0.1188, -0.8096,  ...,  1.7471, -0.0998, -0.7959],
        [ 0.1656, -0.8545, -0.7822,  ...,  0.4219,  2.1934, -0.7173],
        [-0.6323, -0.9707, -0.6431,  ...,  0.2032,  0.1669,  0.9775],
        ...,
        [-0.1829, -0.6714, -0.5303,  ...,  0.5532,  1.4346, -0.2050],
        [-1.0596, -0.4802, -0.4514,  ...,  0.1726,  0.4612,  1.8457],
        [-0.1710, -0.8711, -0.2627,  ...,  0.4910,  0.4683,  0.3899]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 99,  53,  69,  81,  61,  18,  78,  33,  39,  43,  65,  50,  90,  86,
         55,  15,  83,  32,  14,  77,  33,  82,  62,  24,  25,  84,  51,   2,
         56,  73,  89, 101,  14,  97,   9,  26,  49,  48,  85,  97,  49,  89,
         38,  77,  43,  64,  24,  72,  19,  72], device='cuda:0')
tensor([[-0.6704, -0.1875, -0.4487,  ...,  0.4976,  0.4893, -1.0273],
        [ 0.0112, -0.6245,  0.3772,  ...,  0.1058,  0.5449, -0.7544],
        [-0.3186, -1.1406, -0.1122,  ...,  0.7080,  1.45

tensor([[-0.4946,  0.3918, -0.0792,  ...,  0.1000,  0.5645,  0.3264],
        [-0.8691, -0.1903, -0.2893,  ...,  0.3982, -0.1287,  1.0469],
        [-0.0618, -1.2715, -0.2842,  ...,  1.0312,  1.8096,  0.5303],
        ...,
        [-0.6021, -0.8896, -0.8193,  ...,  0.2351, -0.3545,  0.6240],
        [-0.3047,  0.7012,  0.2649,  ...,  0.2013, -0.4102,  0.5127],
        [-0.3901,  1.0381,  0.7212,  ...,  0.8335,  0.6709,  0.1153]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([79, 60, 84, 51, 43, 92, 80, 13, 94, 96, 54, 73, 89, 54, 48, 38, 66, 80,
        77, 67, 94, 60, 36, 15, 49, 49, 10, 51, 90, 52, 85, 45, 84, 79, 83, 58,
        59, 39, 40, 49, 58, 70, 24, 23, 51, 97, 47, 14, 77, 77],
       device='cuda:0')
tensor([[-3.4131e-01, -3.9575e-01,  1.0330e-02,  ...,  6.8701e-01,
         -2.8784e-01, -2.9224e-01],
        [-3.6084e-01, -1.3057e+00, -6.1475e-01,  ...,  1.0938e+00,
          8.6719e-01,  1.0840e+00],
        [-6.4453e-01, -6.4502e-01, -2.822

tensor([[ 2.8369e-01, -2.4146e-01,  6.4087e-02,  ...,  7.1729e-01,
         -4.1260e-01, -1.3887e+00],
        [-1.6321e-01,  4.3042e-01,  5.6396e-01,  ..., -2.8418e-01,
         -2.9126e-01, -4.5898e-01],
        [-2.3328e-01, -8.1396e-01, -5.7080e-01,  ...,  8.0957e-01,
          2.6147e-01, -5.8643e-01],
        ...,
        [-7.1729e-01, -1.1859e-01,  8.7842e-01,  ...,  3.5522e-01,
         -5.6738e-01,  2.0190e-01],
        [-1.5942e-01, -8.5107e-01, -5.9131e-01,  ...,  2.6302e-03,
          4.8584e-01, -4.2285e-01],
        [ 3.2495e-01, -4.0619e-02, -4.8975e-01,  ...,  6.0059e-02,
          2.8633e+00, -3.0151e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 61,  46,  98,  23,  38,  31,  49,  43,  48,   5,  58,   6,  32,  97,
         13,  43,  97,  98,  90,  33,  11,  43,  51,  82,  94,  90,  97,  71,
         88,  18,  96,  75,  34,  78,  14,  67,  66,  64,  13,  71,  83,  36,
          8,  21,  77, 100,  89,  57,  63, 100], device='cuda:0'

tensor([[-0.7031,  0.1008, -0.0092,  ..., -0.2041,  0.5249,  1.5684],
        [-0.3801, -1.3350, -0.6084,  ...,  0.9849,  0.7695,  0.6948],
        [-0.6187, -0.4299, -0.9634,  ...,  0.0269, -0.4697,  0.1818],
        ...,
        [ 0.1710, -0.3350, -0.5171,  ...,  0.5122, -0.0528, -0.5273],
        [-0.2424, -0.6167, -1.0420,  ...,  0.6851,  1.0215,  0.5044],
        [-0.6216,  0.5762,  1.3447,  ..., -0.3398,  0.0900, -0.0193]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([95, 81, 89, 66, 43, 20, 43, 85, 68, 63,  3, 32, 49, 37, 77, 97, 82, 53,
        82, 83, 71, 62, 90, 88, 90, 78, 78, 60, 32, 17, 90, 10, 51, 82, 96, 78,
        82, 93, 73, 51, 47, 65, 82, 63, 64, 41, 80, 68, 75, 52],
       device='cuda:0')
tensor([[-0.0785, -0.6216, -0.1300,  ...,  0.9062,  0.3171, -0.6011],
        [-0.0428, -0.4558, -0.5142,  ...,  1.0820,  0.1311, -0.7119],
        [-0.3123,  0.5669, -0.6274,  ..., -0.0519,  0.7441,  0.1378],
        ...,
        [-0.5547, -0.446

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 23,  73,  77,  52,  90,  89,  73,  26,  43,  59,  95,  91,  43,   7,
         11,  89,  58,  63,  75, 100,  76,   9,  73,  81,  68,   4,  44,  77,
         34,  92, 101,  71,  89,  48,  11,  90,  54,  73,  81,  79,  74,  40,
         46,  35,  48,  43,  45,  77,  55,  84], device='cuda:0')
tensor([[-0.2366,  0.6689,  0.9326,  ...,  0.2190, -0.2866, -0.8237],
        [-0.2891, -0.6904, -0.3337,  ...,  0.3896,  1.3477, -0.3162],
        [-0.0260, -1.2627, -0.1196,  ...,  0.7520,  1.7559,  0.1309],
        ...,
        [ 0.3179, -0.2681, -0.6377,  ..., -0.1345,  2.9609, -1.0039],
        [ 0.0885, -0.9082, -0.4702,  ...,  0.2915,  0.1814, -0.9258],
        [-0.3071, -0.2612, -0.8262,  ..., -0.2339,  0.7520, -0.6309]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([  6,  90,  84,  31,   1,  73,  82,  65,   6,  88,  31,  97,  81,  50,
         78,  84,  56,  74,  90,  23,  83,  89, 

tensor([[-0.6426,  0.1096,  0.9492,  ..., -0.0874, -0.1003,  0.3364],
        [-0.0673, -0.3787, -0.4241,  ...,  1.9570,  0.4387, -0.6460],
        [-0.6855, -0.7588, -0.7480,  ..., -0.1331,  0.0080,  0.8506],
        ...,
        [-0.2742, -0.8618,  0.1007,  ...,  0.0980, -0.9204,  0.4731],
        [-0.2255,  1.2490,  0.9917,  ..., -0.0741, -0.0801, -0.7017],
        [ 0.1813, -0.1826, -0.1602,  ...,  0.4500,  1.8516, -1.0850]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 38,  99,  88,  28,  74,  83,  12,  81,  40,  77,  77,  76,  77,  91,
         75,  77,  23,  84,  40,  85,  50,  72, 101,  90,  38,  96,  32,  85,
         17,  49,  44,   1,  49,  49,  77,  28,  47,  46,  77,  50,  50,  40,
         81,  41,  70,  74,   8,  12,   6,  76], device='cuda:0')
tensor([[-0.3799,  0.8052,  0.0880,  ...,  0.4509,  0.7622,  0.8809],
        [ 0.1774, -0.6348, -0.4541,  ...,  0.0215,  1.1680, -0.7012],
        [-0.8369, -1.0186, -0.7266,  ..., -0.2532,  0.25

tensor([[-0.3403, -0.5264, -1.1006,  ...,  0.5190,  0.5566,  0.3582],
        [ 0.1628, -0.5254, -0.6714,  ...,  0.3235,  1.4805, -0.9272],
        [-0.3293,  2.0645,  0.5620,  ...,  0.1210, -0.0108, -0.2300],
        ...,
        [-0.3401,  0.1790, -0.1769,  ..., -0.3357,  0.5171,  1.6484],
        [-0.6777, -1.0791, -0.6045,  ..., -0.1851,  0.3181,  0.2566],
        [ 0.1453, -0.3782, -0.5000,  ...,  0.8286, -0.3643, -0.6548]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 75,  87,   8,  80,  56,  59,  72,  41,  80,  16,  85,  41,   5,  76,
         53,  71,  81, 101,  40,  64, 101,  43,  73,  54,  96,  11,  51,  11,
         64,  12,  56,  25,  11,   4,  63,  46,   6,  58,  24,  92,  77,  15,
         95,  90,  84,  66,  35,  95,  89,  68], device='cuda:0')
tensor([[-0.5576, -0.7778, -0.7329,  ...,  0.2671,  1.4053,  0.7686],
        [ 0.1642, -0.4646,  0.0075,  ..., -0.0446,  2.0547, -0.3760],
        [ 0.2759, -1.1777, -0.2432,  ...,  0.1274,  1.78

tensor([[-0.2581, -0.3130, -0.2756,  ...,  0.5293,  0.2257, -0.5601],
        [ 0.3369, -0.5132, -0.3464,  ...,  0.1177,  1.5869, -0.9111],
        [-0.7070,  0.0122,  1.3643,  ..., -0.1926, -0.7026,  1.0391],
        ...,
        [-0.3589, -0.9429, -0.7256,  ...,  0.6162,  1.5127,  0.0662],
        [-0.6772, -0.4941,  0.8223,  ...,  0.1538, -1.1953,  0.5562],
        [ 0.2292, -0.5229,  0.0072,  ...,  0.5352,  0.1880, -0.9619]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  50,  71, 100,   0,  26,  38,  65,  56,  59,  40,  70,  59,  97,
         41,  49,  77,  33,  40,  81,  24,  81,  89,  46,  65,  26,  11,  59,
         63,  75,  75,  65,  51,  12,  34,  39,  73,  46,  49,  77,  90,  49,
         47,  82,   4,  99,  43,  97,  59,  45], device='cuda:0')
tensor([[-0.0967, -0.8442,  0.1561,  ...,  1.1738,  0.0668, -0.0138],
        [-0.0285, -0.7227, -0.4050,  ...,  0.5107,  0.2152,  1.1475],
        [-0.7061,  0.2666, -0.8657,  ...,  0.7231, -0.47

tensor([[ 0.2883, -0.8960, -0.0063,  ...,  0.9189,  1.2666,  0.0343],
        [ 0.2206, -0.8657,  0.3560,  ...,  0.4143,  0.3037, -0.2390],
        [-0.0018,  0.1619,  0.6782,  ..., -0.1218,  0.1924, -0.3403],
        ...,
        [-0.1245, -0.7983, -0.0099,  ...,  0.7520, -0.0732, -0.3267],
        [-0.4634,  0.1547, -0.0601,  ...,  0.5571,  0.9531,  1.4863],
        [-0.1948, -0.6323, -0.6709,  ...,  0.7324,  1.7529,  0.2139]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 84,  12,  65,   8,  18,  22,  92,  18,  22,  66,  25,  39,  44,  45,
          7,  65,  51,  51,   9, 100,  89,  74,  98,  11,  48,  78,  89,   8,
         75,  74,  75,  90,  96,  64,  72,   3,  82,  83,  74,  64,   6,  72,
         94,   3,  33,  77,  41,  37,  95,  75], device='cuda:0')
tensor([[-0.2336, -0.3218,  0.0310,  ...,  0.2891,  0.8252,  0.4907],
        [-0.3843, -0.4424, -0.1681,  ...,  0.2163, -0.9058,  1.2490],
        [-0.4736,  0.8413,  1.2188,  ..., -0.0163,  0.25

tensor([[-0.2832, -1.1016, -0.3801,  ...,  0.9243,  0.1219, -0.0269],
        [-0.4224, -0.8496, -0.0760,  ...,  0.1985,  0.1017, -0.6030],
        [-0.6455, -0.6372, -0.3030,  ...,  0.9209,  0.4014,  0.8696],
        ...,
        [ 0.5332, -0.3618, -0.8799,  ...,  0.3535,  1.4805, -1.0449],
        [-0.5342, -0.7251,  0.5879,  ...,  0.2012, -0.8784,  0.5503],
        [-0.7314, -0.6934, -0.8130,  ...,  0.3384,  0.1372,  1.0322]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([90, 43, 81, 77, 53, 97, 16, 55, 77, 90, 86, 72, 87, 62, 98, 81,  9, 15,
        17, 65, 76, 74,  8, 59, 77, 75, 15, 97, 26, 74, 29, 55, 17, 97, 75, 57,
        41, 50, 84, 13, 66, 99, 14, 73, 43, 28, 67, 13, 54, 88],
       device='cuda:0')
tensor([[-0.1675, -0.7266, -0.5415,  ...,  0.3213, -0.2852,  1.4639],
        [-0.8813, -1.1230, -0.2844,  ..., -0.1315,  0.1467,  1.4141],
        [-0.4954, -1.1279, -0.8550,  ...,  0.6899,  2.0000,  0.8013],
        ...,
        [ 0.2598, -0.319

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,   6,  81, 101,  18,  11,  49,  75,  84,  24,  40,  27,  23,  87,
         87,  27,  74,  74,  34,  85,  97,  43,  98,  38,  54,  90,  79,  33,
         40,  62,  33,  96,  41,  13,  73,  72,  99,  97,  77,  74,   3,  54,
         87,  45,  80,  73,  77,  43,  36,  45], device='cuda:0')
tensor([[-0.9922, -0.8311,  0.1351,  ..., -0.0688,  0.5034,  3.0684],
        [-0.3635, -1.3828, -0.3521,  ...,  1.0137,  0.9458,  0.1433],
        [ 0.1659, -1.1328, -0.4670,  ...,  0.3950,  1.3252, -0.0840],
        ...,
        [-0.7231, -0.8135, -0.1210,  ...,  0.1085,  0.1763,  1.0947],
        [-0.7080,  1.5830,  1.3555,  ...,  0.1014, -0.3508,  0.6279],
        [-0.5098, -1.1074, -0.8369,  ...,  0.0691,  0.7075, -0.5723]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([101,  41,  42,  62,  13,  49,  65,  53,  43,  34,  82,   6,  57,  82,
         24,  89,  18,  27,  12,  90,  17,  77, 

tensor([[-0.8013, -0.2751, -0.2664,  ...,  0.8809, -0.1332,  0.8525],
        [-0.5991, -0.8218, -0.3499,  ...,  0.0183,  0.3806,  0.8223],
        [-0.6763, -0.0959, -0.4988,  ...,  0.0411,  1.0371,  0.8843],
        ...,
        [-0.4773, -1.1357, -0.2488,  ..., -0.3850,  1.7773,  0.2979],
        [-0.2003,  0.1310, -0.8472,  ...,  0.4004, -0.6406, -0.3977],
        [-0.2059,  0.6431,  0.3115,  ...,  0.0334,  0.0850, -0.1940]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([81, 41, 92, 80,  6, 83,  1, 25, 25,  4, 83, 26, 96, 49, 74, 74, 85, 67,
        74, 70, 83, 52, 67, 89, 59, 37, 34, 29, 69, 62, 12, 15, 43, 67, 60, 64,
        39,  5, 77, 28, 54, 36, 73, 14, 67, 78,  8, 86, 18, 77],
       device='cuda:0')
tensor([[-0.2250,  0.8066, -0.2976,  ...,  1.5049, -0.2069, -0.3965],
        [-0.8667,  0.2101,  1.4248,  ..., -0.5444, -0.5254,  0.8071],
        [-0.2512, -0.7090, -0.6357,  ...,  0.4309,  0.8799, -0.4814],
        ...,
        [-0.7759, -1.420

tensor([[-0.3162,  1.4297,  0.2393,  ...,  0.5303, -0.2910, -0.6099],
        [-0.5718, -0.6187,  0.0641,  ...,  0.0241,  0.6553, -0.3528],
        [-0.6099,  0.6411,  0.3264,  ..., -0.0974,  0.4971,  0.5806],
        ...,
        [-0.7075, -1.0020, -0.1727,  ...,  0.2222, -0.1337,  0.8237],
        [ 0.3826, -0.4194,  0.1431,  ...,  0.2423,  0.8784, -1.0566],
        [-0.6836, -1.1826, -0.7690,  ...,  1.0967,  0.4380,  1.1992]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([34, 43, 29, 14, 23, 11, 48, 49, 30, 75, 82, 98, 82, 57, 90,  9, 80, 33,
        81, 66, 33, 71, 21, 49, 91,  3, 79, 26, 86, 23, 89, 71, 98, 42, 26, 84,
        49, 25, 52, 73, 54, 74, 43, 53, 97, 91, 99,  4, 50, 81],
       device='cuda:0')
tensor([[-0.1794, -0.4741, -0.4011,  ...,  2.7812,  0.1126, -0.0659],
        [-0.1051, -1.0322, -0.4426,  ..., -0.1699,  1.1729, -0.5684],
        [-0.7485, -0.1812, -0.5430,  ...,  0.2795,  0.7578,  1.6143],
        ...,
        [-0.1494,  1.103

tensor([[-0.3135,  2.1797,  0.5103,  ...,  0.0255,  0.6133, -0.1345],
        [-0.7939,  0.6919,  1.2822,  ..., -0.4250, -0.7310,  0.2859],
        [-1.0898,  0.1759,  0.7598,  ..., -0.1532,  0.3726,  0.4260],
        ...,
        [-0.1689, -0.5767,  0.2200,  ...,  0.0152,  0.9155, -0.5732],
        [-1.0654, -0.4956, -0.2927,  ..., -0.0243,  0.5239,  2.4434],
        [ 0.1436, -0.6367, -1.2217,  ...,  0.2340,  1.4277, -1.3223]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 31,  52,  22,  24,  44,  71,  77,  75,   5,  78,  64,  94,  49, 101,
        101,  43,  98,  62,  87,  32,  37,  72,  84,  90,  42,  28,  73,  92,
         98,  96,  73,  20,  68,  12,   7,  38,  90,  89,  80,  93,   6,  85,
         85,  83,  31,  77,  19,  42,  19,  36], device='cuda:0')
tensor([[-0.3079, -0.9185, -0.5889,  ...,  0.2749,  1.5986,  0.0927],
        [-0.5288, -0.5366, -0.1316,  ..., -0.0743, -0.4355,  0.8960],
        [ 0.2139, -0.1299, -0.5977,  ...,  0.5366,  0.97

tensor([[-6.9629e-01, -5.4736e-01,  1.8326e-02,  ...,  1.0777e-03,
         -3.5010e-01,  1.5637e-01],
        [-3.9868e-01, -1.4375e+00, -1.6708e-02,  ...,  4.8047e-01,
          6.2256e-01, -4.7437e-01],
        [-7.2119e-01, -8.8184e-01,  1.5674e-01,  ...,  2.9272e-01,
          6.4111e-01,  1.1748e+00],
        ...,
        [-5.7861e-01, -7.5732e-01, -9.7998e-01,  ..., -3.0322e-01,
          8.5303e-01,  2.7295e-01],
        [-7.4414e-01, -1.0391e+00,  5.3467e-01,  ...,  1.4856e-01,
         -1.4111e-01,  1.7832e+00],
        [-3.4847e-03, -1.1484e+00, -7.6074e-01,  ..., -1.2054e-01,
          1.8389e+00,  5.5786e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([74,  5, 92, 83, 69, 43, 98, 47,  6, 74,  8, 40, 56, 81, 54, 20, 83, 55,
        62, 50, 50, 80, 40, 60, 72, 90, 78, 74, 33, 82, 97, 47, 45,  5, 69, 99,
        11, 15, 23, 99, 42, 69, 11, 34, 81, 95, 41, 40, 16, 72],
       device='cuda:0')
tensor([[-0.0660,  1.5557,  1.1309,  ...,  0.1737

tensor([[-0.0714, -1.1592, -0.3447,  ...,  0.2009,  1.2148,  0.6060],
        [ 0.2316, -0.2026,  0.2815,  ...,  0.8818,  1.3350, -0.3948],
        [-0.2184,  1.3154,  0.3972,  ..., -0.2725,  0.1351,  0.4612],
        ...,
        [ 0.0999, -0.7637, -0.5645,  ...,  0.4949,  1.9746,  0.0770],
        [-0.2029, -0.5186, -1.0781,  ...,  0.7236,  1.0029, -0.3403],
        [ 0.2393, -1.3184, -0.3271,  ...,  1.4668,  1.1699, -0.4968]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([72, 49, 29, 87, 32, 48, 71, 31,  6, 81, 74, 56, 43, 48, 74, 48, 97, 51,
        82, 73, 32,  6, 55, 14, 81, 66, 97, 60, 90, 54, 76, 57, 12, 37, 22, 45,
        40, 59, 97, 19, 82, 54, 49, 42, 37, 67, 43, 72, 75, 90],
       device='cuda:0')
tensor([[ 0.1019,  1.2949,  0.2969,  ..., -0.3875,  0.3491, -0.6826],
        [-0.5845,  0.1685,  2.3223,  ..., -0.4058, -1.0322,  0.6597],
        [-0.5215, -0.4561,  0.1681,  ...,  0.2327,  0.2805, -0.8716],
        ...,
        [-1.0195, -1.192

tensor([[-8.8574e-01, -1.0820e+00, -2.9321e-01,  ..., -8.5986e-01,
          5.8350e-01,  7.7783e-01],
        [-2.5903e-01, -3.5645e-01, -2.2681e-01,  ..., -5.8044e-02,
          2.0733e-03, -3.1104e-01],
        [-1.5894e-01, -5.4834e-01, -5.7471e-01,  ...,  8.3301e-01,
         -1.3708e-01,  1.0645e+00],
        ...,
        [-6.4258e-01,  2.1094e+00,  5.4053e-01,  ..., -1.7502e-02,
          1.9791e-02, -3.5449e-01],
        [-7.9785e-01, -6.9043e-01, -5.7617e-01,  ...,  1.0645e+00,
          8.3057e-01,  6.9727e-01],
        [-6.6016e-01,  1.0791e-01,  6.0303e-01,  ...,  3.6548e-01,
          8.4033e-01,  5.3857e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([ 82,  66,  37,  77,  78,  54,  20,  73,  90,  49,  91,  18,  33,  54,
         43,  39,  61,  32,  38,  80,  60,  96,  72,  24,  19,  82, 101,  97,
          3,  16,  77,  14,  74,  71,  54,  83,   6,  60,  89,  83,  23,  83,
         57,  78,  81,   6,  44,  31,  40,  22], device='cuda:0'

tensor([[-0.4792, -1.2119, -0.5562,  ...,  0.2330,  0.1930,  1.1992],
        [ 0.1671, -0.7793, -0.3499,  ...,  0.8359,  1.1914, -1.1016],
        [-0.3821, -1.1836, -0.2487,  ...,  0.3022,  1.3721, -0.6250],
        ...,
        [-0.3909, -1.6484, -0.3186,  ...,  0.2563,  1.5938,  0.8848],
        [-0.3367, -0.6035, -0.2075,  ...,  0.0598,  2.5273,  0.1389],
        [-0.2742,  0.7192,  0.1763,  ...,  0.8340,  1.6895, -0.0747]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 85,  53,   5,  94,  71,  47,  40,  56,   7,   1,  41,  50,  78,  96,
         32,  81,  73,   0,  99,  69,   9,  72,  51,  97,  22,  53,  77,  54,
         81, 100,  52,  59,  43,  82,  82,  73,  77,  33,  33,  94,   9,  85,
         75,  97,  49,  42,  50,  72,  76,  77], device='cuda:0')
tensor([[-0.5132, -0.4299, -0.3567,  ...,  0.4006,  0.6265,  0.2908],
        [-0.2495, -0.4397,  0.3289,  ...,  0.9805,  0.4438, -0.8154],
        [-0.0645,  1.0947, -0.1515,  ..., -0.3411,  0.43

tensor([[-0.3770, -0.7842, -0.6260,  ..., -0.1140,  1.1846,  0.8276],
        [-0.7227, -0.8589, -0.3760,  ...,  0.6431,  0.9565,  0.0774],
        [-0.2554, -0.9106, -0.4070,  ...,  0.2837,  1.3506, -0.9302],
        ...,
        [-0.0502,  2.6172,  0.7974,  ...,  0.3118, -0.0558, -1.1172],
        [-0.3394, -1.5322, -0.2086,  ...,  1.4629,  0.6582,  0.4717],
        [-0.6357, -0.6738,  0.3928,  ...,  0.0682, -0.1477,  0.2361]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 32,  92,   5,  57,  80,   5,  43,  86,   9,  84,  96,  78,  53,  64,
         22,  57,  18,  28,  71,  73,  77,  15,  86,  99,  84,  18,  57,  82,
         51,  43,  97,  81,  53,   2, 101,   9,  15,  75,  59,  18,  24,  54,
         88,  51,  80,  89,  30,   8,  81,  74], device='cuda:0')
tensor([[ 0.1500, -0.3855, -0.3862,  ...,  1.3145,  0.9561, -1.1436],
        [ 0.2976, -0.9106, -0.0246,  ...,  0.1758,  1.0156, -0.8696],
        [-0.2391, -0.4993, -0.8364,  ...,  0.1614,  2.03

tensor([[-7.2266e-01, -1.8604e-01, -7.1729e-01,  1.2471e+00,  2.3853e-01,
          1.9385e-01, -1.2383e+00,  1.7664e-01, -3.3740e-01, -3.6011e-01,
         -7.1777e-02,  6.2402e-01, -2.3596e-01, -3.8184e-01,  4.6875e-01,
          1.3271e+00,  4.3872e-01, -5.2393e-01,  6.9092e-01,  1.1826e+00,
          1.2781e-01,  4.1187e-01, -2.3633e-01, -6.8652e-01, -2.3975e-01,
         -4.4849e-01, -3.0737e-01,  9.2285e-01,  6.8359e-02, -3.1250e-01,
         -3.3887e-01, -8.3252e-01, -4.5996e-01, -1.9031e-01, -4.7607e-01,
          1.0193e-01,  2.2266e-01, -1.4026e-01,  8.8916e-01, -6.3623e-01,
          7.1826e-01,  7.5098e-01, -7.2656e-01, -8.5010e-01, -6.8115e-01,
         -8.5449e-01,  3.3264e-02, -2.4097e-01, -1.3506e+00, -1.1975e-01,
         -1.0312e+00, -6.8896e-01,  2.0459e-01, -5.7617e-01,  4.9609e-01,
          1.0175e-01,  3.9734e-02, -2.2510e-01,  3.7354e-01, -4.4336e-01,
         -2.2461e-01, -2.4902e-01, -3.4698e-02, -2.8833e-01, -8.6475e-01,
         -1.4893e-01, -6.7578e-01, -5.

tensor([[ 0.1344, -0.1763, -0.2186,  ..., -0.0326,  3.1348, -0.5039],
        [ 0.7095, -1.1650, -0.1925,  ...,  0.9634,  1.0791, -0.9834],
        [ 0.1746, -0.5464, -0.8174,  ...,  0.4380,  0.3420, -0.6968],
        ...,
        [-0.4395, -0.9111, -0.2798,  ...,  0.9131,  1.2080,  0.3411],
        [-0.5327, -0.5273, -0.3452,  ..., -0.2651, -0.3904,  1.3896],
        [-0.3687, -1.3037, -0.3367,  ...,  0.6714,  1.0615, -0.2391]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([100,   0,  55,  56,  77,  32,  44,  19,  33,  19,  53,  74,  97,  56,
         22,  73,  49,   3,  86,  97,  71,  23,  92,  50,  75,  53,  67,  43,
         78,  76,  76,  38,  54,  59,  98,  33,  70,  11,  25,  93,  96,  77,
         36,  50,  86,  93,  49,   3,  73,   3], device='cuda:0')
tensor([[-0.6421, -0.0412, -0.5752,  ..., -0.1006, -0.3904,  0.0524],
        [ 0.0288,  1.9043,  0.9697,  ...,  0.1635, -0.4534, -0.8052],
        [-0.2703,  0.6978,  2.0039,  ..., -0.0617, -0.61

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 71,  18,  77,  49,  74,  45,  59,  37,  36,  92,  16, 101,  50,  29,
         10,  77,  75,  49,  23,  77,  45,  94, 100,  10,  54,  49,  85,  45,
         51,   6,  43,  54,  49,  27,  40,  88,  77,  78,  64,  43,   7,  96,
         84,  44,  89,   2,  59,  36,  22,  27], device='cuda:0')
tensor([[-0.1052,  1.3662,  0.9976,  ..., -0.0967,  0.0085, -0.3967],
        [ 0.3486, -0.6060, -0.8667,  ...,  0.6606,  1.1670, -0.4915],
        [ 0.5073, -0.2690, -0.3328,  ...,  1.0195, -0.2605, -1.2285],
        ...,
        [-0.4055, -0.4014, -0.4075,  ...,  0.2583, -0.0767,  0.7812],
        [ 0.0106,  0.8774, -0.3562,  ...,  0.1995,  0.0090, -0.1110],
        [-0.6064,  3.5430,  0.3606,  ...,  0.2150, -0.1277,  0.1963]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 48,  49,  61,  50,  35,  15,  46,   5,  44,  51,  90,  71,  27,  39,
         37,  18,  60,  23,  37,  71,  83,  81, 

tensor([[ 0.0341, -1.0088, -0.7324,  ...,  0.2522,  1.6152,  0.0866],
        [-0.6660, -1.0557, -0.4829,  ...,  1.1084,  0.0890,  1.0439],
        [-0.7935, -0.7637, -0.2751,  ..., -0.3975, -0.5386,  1.1162],
        ...,
        [-0.3542, -0.8062, -0.2761,  ...,  0.7930,  1.0332,  0.6963],
        [-0.5669, -0.4155, -0.9673,  ...,  0.1008,  0.5659,  0.0024],
        [-0.3176, -0.2157, -0.6187,  ...,  0.8325, -0.7588,  0.6343]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 72,  81,  73,  75,  76,  74,  59,  75,  49,  49,  91,  11,  90, 100,
         77,  14,  55,  99,  70,  49,  56,  64,  86,  58,  83,  48,  17,  59,
         58,  21,  59,  60,   5,  93,  91,  56,  43,  49,  94,  18,  32,  48,
         97,  33,  12,  13,  73,  41,  55,  18], device='cuda:0')
tensor([[-6.2598e-01,  1.0820e+00, -3.8501e-01,  ...,  5.4346e-01,
          7.2510e-02, -2.6587e-01],
        [-5.6836e-01, -6.3232e-01,  6.4990e-01,  ...,  1.9943e-02,
          9.2957e-02, -4.4

tensor([[ 0.4080, -0.0277,  0.0709,  ...,  0.8667,  1.7559, -0.5166],
        [-0.3452, -0.7129, -0.7139,  ...,  0.9370,  1.5879,  0.8359],
        [-0.4636, -0.7368, -0.6782,  ...,  0.7861,  0.7417, -0.6631],
        ...,
        [-0.0927,  0.5503, -0.3818,  ...,  1.0234, -0.0318, -0.9868],
        [ 0.3162, -0.5200,  0.1801,  ...,  0.6392,  1.4023, -0.9131],
        [ 0.1470, -0.6118,  0.3066,  ...,  0.2883, -0.0122, -0.9692]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 83,  75,  49,  49,  91,  63,   8,   5,  96,  87,  82,   6,  44,  39,
         62,  79,  82,  76,  76,  94,  76,  96,  90,   3,  42,  79,  14,  84,
         48,   4,  74,  76,  99,  20,  74,  41,  59,   4,  49, 100,   3,   1,
         63,  77,  63,  96,  48,  96,  87,  64], device='cuda:0')
tensor([[-0.1066, -0.8047, -0.0471,  ...,  0.5405,  1.7412,  0.6948],
        [-0.0809, -1.1455,  0.0583,  ...,  0.6436,  2.3613,  0.7109],
        [-0.1483,  0.5728,  0.6792,  ..., -0.4355, -0.30

tensor([[ 0.1312,  0.2146,  0.5791,  ..., -0.5645,  1.1064,  0.2390],
        [-0.0658, -0.3262, -0.6147,  ...,  0.4990, -1.1689, -0.6748],
        [-0.1849, -0.8970, -0.5713,  ..., -0.1356,  1.8906, -0.2446],
        ...,
        [ 0.0405, -0.8042, -0.6304,  ...,  0.5205,  0.0524,  0.0360],
        [-0.2200, -1.1553, -0.4910,  ...,  0.2876,  1.3223, -0.3740],
        [-0.2666,  1.3281,  1.1465,  ..., -0.0101, -0.2959, -0.2260]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 26,  68,  72,  89,  97,  49,  57,  13,  11,  38,  54,  87,  36,  33,
         77,  90,  82,  73,  36,   1,  30,  73,  75,  76,  35,  56,  14,  38,
         36,  81,  78,  94,  33,  23,  76,  49,  40,  37,  64,  30,  16,  49,
         48,  11,  38,  96, 101,  63,  97,  48], device='cuda:0')
tensor([[-0.6021,  0.3083,  1.2168,  ..., -0.3123,  0.1686, -0.2013],
        [ 0.2913, -0.4473, -0.6411,  ...,  1.9346, -0.6992, -0.4858],
        [-0.4136, -0.1267, -0.9844,  ..., -0.2461,  0.55

tensor([[-0.9546,  1.4180, -0.7139,  ...,  0.1262,  0.3381,  0.1385],
        [-0.8628,  0.4360, -0.0786,  ...,  1.0244,  0.4185,  0.8071],
        [-0.9937, -1.0166,  0.1860,  ...,  1.1201,  0.1207,  1.7549],
        ...,
        [-0.4749, -0.5625, -0.6870,  ...,  0.2983,  0.7930, -0.7632],
        [ 0.1588,  1.1768,  0.0779,  ..., -0.3396,  0.9087, -0.8188],
        [-0.0040, -0.4250, -0.8208,  ...,  1.0488,  0.9038,  0.0559]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 20,  77,  16,  27,  74,  46,  74,  45,  19,  45,  54,  81,  54,   3,
         13,  54,  36,  89,  82,  65,  54,   9,  15,   9,  29,  75,  73,  91,
         82,  27,  82,  73, 101,   3,  40,  51,  83,  23,  41,   4,  89,  60,
         74,  23,  46,  56,  11,   5,  17,  14], device='cuda:0')
tensor([[-0.0692,  0.2423, -0.8198,  ...,  2.1074, -0.0942, -0.9043],
        [ 0.4021, -0.8452, -0.7842,  ...,  0.5620,  2.4961, -0.7798],
        [-0.5625, -1.0176, -0.6421,  ...,  0.2903,  0.18

tensor([[-0.3022,  0.6167,  0.4333,  ...,  0.0821,  0.8521,  0.9854],
        [ 0.4448, -0.4055,  0.3455,  ...,  0.9531,  0.3618, -1.0029],
        [-0.8706, -0.8394, -1.0244,  ...,  0.9990,  1.2236,  0.5962],
        ...,
        [-0.3657, -0.9331, -0.1721,  ...,  1.0186,  0.9126,  1.5137],
        [-0.4045, -1.0449, -0.1471,  ...,  0.8164,  0.1792, -0.2285],
        [-0.1318, -0.0564, -0.3938,  ...,  0.4597,  0.4412, -0.6040]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([77, 45, 97, 10, 78, 39, 81, 97, 68, 67, 96, 72, 89, 49, 53,  5, 77, 20,
        74, 55, 68, 78, 33, 70, 12, 37, 31, 58, 73,  0, 79, 97, 44, 53, 74, 31,
        73, 11, 50, 14, 67, 61, 23, 26, 84, 47, 88, 41, 56, 96],
       device='cuda:0')
tensor([[-0.0123, -1.0459, -0.2964,  ...,  0.0042,  0.9556, -0.0651],
        [-0.3838, -0.8711, -0.5420,  ...,  0.3950,  1.2158,  0.9268],
        [ 0.3174, -1.0107, -0.2532,  ...,  0.4194,  2.2402, -0.2023],
        ...,
        [-0.3582, -1.150

tensor([[-0.4084,  0.5200,  0.0335,  ...,  0.1843,  0.6499,  0.4143],
        [-0.8428, -0.1160, -0.2057,  ...,  0.5151, -0.1705,  1.2207],
        [ 0.1208, -1.3164, -0.1802,  ...,  1.2480,  2.0742,  0.6313],
        ...,
        [-0.5483, -0.9150, -0.8481,  ...,  0.3040, -0.3975,  0.7437],
        [-0.1923,  0.9155,  0.4351,  ...,  0.2898, -0.4624,  0.6187],
        [-0.2681,  1.2783,  0.9731,  ...,  1.0059,  0.7764,  0.1635]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([79, 60, 84, 51, 43, 92, 80, 13, 94, 96, 54, 73, 89, 54, 48, 38, 66, 80,
        77, 67, 94, 60, 36, 15, 49, 49, 10, 51, 90, 52, 85, 45, 84, 79, 83, 58,
        59, 39, 40, 49, 58, 70, 24, 23, 51, 97, 47, 14, 77, 77],
       device='cuda:0')
tensor([[-0.2145, -0.3608,  0.1560,  ...,  0.8613, -0.2949, -0.3203],
        [-0.2795, -1.3682, -0.5928,  ...,  1.3408,  1.0078,  1.2373],
        [-0.5923, -0.6172, -0.1852,  ..., -0.3047, -0.3770,  0.9316],
        ...,
        [ 0.1254, -0.647

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([26, 45,  5, 26, 73, 19, 84, 10, 69, 18, 74, 40, 85, 96, 75, 59, 32,  4,
        89, 90,  7, 32, 32, 71, 84, 72, 45, 68, 34, 68, 67, 32, 53, 73, 36, 92,
        88, 43, 83, 77, 49, 50, 11, 32, 84, 99, 11, 21, 70, 60],
       device='cuda:0')
tensor([[-6.4209e-01,  2.6123e-01,  1.6089e-01,  ..., -1.6406e-01,
          5.5176e-01,  1.8145e+00],
        [-2.7222e-01, -1.3770e+00, -5.4834e-01,  ...,  1.2168e+00,
          8.7744e-01,  8.0762e-01],
        [-5.3369e-01, -3.8940e-01, -9.6094e-01,  ...,  1.3867e-01,
         -5.2246e-01,  2.3193e-01],
        ...,
        [ 4.0088e-01, -2.8296e-01, -4.7729e-01,  ...,  6.7822e-01,
         -4.0161e-02, -5.7861e-01],
        [-1.1627e-01, -5.7178e-01, -1.0596e+00,  ...,  8.3301e-01,
          1.1641e+00,  5.8398e-01],
        [-5.4150e-01,  7.3877e-01,  1.6543e+00,  ..., -2.9224e-01,
          1.2311e-01, -7.5388e-04]], device='cuda:0', dtype=torch.float16,
       grad

tensor([[-1.0430,  0.8550, -0.5845,  ...,  0.9536, -0.3982,  0.7842],
        [-0.3674, -0.3682, -0.7710,  ..., -0.1715,  1.3350,  0.5527],
        [-0.3301,  0.2289, -0.0280,  ...,  0.1266,  0.6958,  0.5449],
        ...,
        [-0.5732,  0.2015,  1.3330,  ...,  0.1304, -0.0981, -0.0908],
        [-0.7578,  0.2744,  1.5078,  ..., -0.0765, -0.2629,  0.5806],
        [-0.2817,  0.1777,  0.4456,  ...,  1.0576,  0.5859,  0.9165]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 58,  32,  23,  93,  72,  19, 100,  52,  75,  26,  73,  57,  56,  51,
         33,  34,  74,  97, 100,   5,  10,  43,  54,  97,  69,  30,  49,  43,
         82,  95,  97,  34,  81,  92,  76,  39,  78,  34,  11,  41,  16,  18,
         41,  43,  44,   5,  80,  47,  52,  77], device='cuda:0')
tensor([[-0.7139,  0.3176,  1.7676,  ..., -0.4858, -0.3726,  0.6997],
        [-0.6470, -0.4998,  0.0794,  ..., -0.0215, -0.0671,  1.6514],
        [ 0.2766, -0.5122, -0.4597,  ...,  0.7319,  3.20

tensor([[-6.7627e-02,  8.4570e-01,  1.1523e+00,  ...,  3.1738e-01,
         -2.9272e-01, -8.9990e-01],
        [-1.3904e-01, -7.1631e-01, -2.3938e-01,  ...,  5.2832e-01,
          1.5586e+00, -3.2910e-01],
        [ 1.6638e-01, -1.3066e+00,  7.3385e-04,  ...,  9.3311e-01,
          1.9980e+00,  1.9482e-01],
        ...,
        [ 5.6494e-01, -2.0862e-01, -6.1230e-01,  ..., -8.3008e-02,
          3.3828e+00, -1.0869e+00],
        [ 3.0054e-01, -9.1553e-01, -4.1577e-01,  ...,  4.1187e-01,
          2.2107e-01, -1.0186e+00],
        [-1.4514e-01, -1.9067e-01, -8.1787e-01,  ..., -1.7212e-01,
          8.3936e-01, -6.9385e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([  6,  90,  84,  31,   1,  73,  82,  65,   6,  88,  31,  97,  81,  50,
         78,  84,  56,  74,  90,  23,  83,  89,  27,  83,  82,  24,  12,  84,
         56,  10,  23, 100,  96,   9,  70,  65,  72,  19,  40,  69,  49,  49,
         67,  59,  82,  30,  65, 100,  50,  28], device='cuda:0'

tensor([[-0.0421,  0.0894, -0.2520,  ...,  0.8604, -0.1202, -0.7568],
        [-0.5449,  1.3613,  1.0654,  ...,  0.2598, -0.0604,  0.6382],
        [-0.6084, -0.2515,  1.0371,  ...,  0.1804, -0.9473,  0.4480],
        ...,
        [-0.4087, -1.2061, -0.6655,  ...,  0.1697, -0.0846,  0.2377],
        [-0.8555,  0.3708,  1.4043,  ..., -0.1960, -0.3420,  0.4915],
        [-0.5942, -0.6890, -0.2317,  ...,  0.8452,  0.3525,  0.8335]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  24,  59,  77, 101,   8,  80,  82,  80,  83,  89,  11,  46,  76,
         96,  77,  97,  78,  68,  77,  54,  55,  76,  80,  80,  87,  82,  76,
          0,  68,  34,  47,  30,  57,  13,  98,  90,  20,  49, 100,  75,  41,
         88,  70,  77,  17,  39,  15,  52,  41], device='cuda:0')
tensor([[-0.2161, -0.4895, -1.1279,  ...,  0.6489,  0.6367,  0.4197],
        [ 0.3811, -0.4863, -0.6484,  ...,  0.4492,  1.6650, -1.0420],
        [-0.1963,  2.4395,  0.7432,  ...,  0.1888, -0.03

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 49,  64,  84,  54,  59,  80,  88,  37,  84,  96,   5,  98,  73,   5,
         61,  15,  67,   9,  57,  49,  44,  75,  96,  76,   7,  49, 100,  88,
         50,  71,  43,  31, 101,   9,  97,   6,  21,  67,  22,   9,  48,  97,
         84,  49,   6,  84,  81,  90,  39,  18], device='cuda:0')
tensor([[-0.0643, -1.5615, -0.1366,  ...,  1.4131,  1.2031,  0.6177],
        [-0.8052,  0.0793,  1.6699,  ..., -0.4888,  0.0701,  0.4978],
        [-0.1576,  0.6416, -0.2200,  ..., -0.0571,  1.4746,  0.1292],
        ...,
        [-0.2006, -1.0664, -0.1616,  ...,  0.8267,  0.3628, -0.3408],
        [ 0.3596, -0.5566, -0.3938,  ...,  0.9561,  1.4824, -0.5347],
        [-0.2059,  0.3547,  0.3569,  ...,  0.6895,  0.5938, -0.0073]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([84, 47, 17, 23, 29, 37, 17, 49, 64, 96, 99, 90,  8, 56, 11, 83, 70, 64,
        72, 77, 94, 61, 40, 42, 94, 63, 82, 12

tensor([[ 0.0679, -0.8579,  0.3137,  ...,  1.4209,  0.1198,  0.0153],
        [ 0.1274, -0.7266, -0.3257,  ...,  0.6050,  0.2690,  1.3330],
        [-0.6401,  0.4185, -0.8423,  ...,  0.8467, -0.5459,  0.0357],
        ...,
        [-0.7188, -0.8208, -0.5962,  ...,  0.3267, -1.1885,  0.5601],
        [ 0.5269, -0.6709, -0.3318,  ...,  0.4517,  0.7329, -1.4268],
        [-0.4102, -0.6489, -0.3276,  ...,  0.2915, -0.8237,  0.9307]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([90, 35, 96, 97, 50, 92, 11, 90, 92,  4, 56, 82, 73, 20, 82, 15, 40, 89,
        74, 82, 26, 27, 64, 37, 42, 95, 55, 75,  8,  1, 81, 96, 98, 73, 81,  6,
        49, 73, 47, 73, 77, 77, 78, 53, 51, 38,  4, 78, 64, 78],
       device='cuda:0')
tensor([[-0.6919, -1.2080, -0.7725,  ..., -0.8472,  0.4587,  0.8125],
        [-0.8320,  0.7129,  1.6064,  ..., -0.1458, -0.2659,  0.2184],
        [ 0.1338, -0.0105, -0.6987,  ...,  0.7207,  1.5273, -0.3062],
        ...,
        [ 0.0351, -0.498

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 88,  71,  76,  30, 100,  83,  84,  91,  33, 100,  97,  62,  94,  25,
         20,  86,  75,  71,  36,  18,  92,  43,  58,  49,  24,  65,  88,   5,
         74, 101,  49,  94,  23,  36,  92,  94,  11,  31,   1,  84,  52,  48,
         39,  23,  82,  38,  97,  27,  23,  24], device='cuda:0')
tensor([[-0.1270, -1.1250, -0.2964,  ...,  1.1494,  0.1318, -0.0107],
        [-0.3020, -0.8584,  0.0625,  ...,  0.3198,  0.1278, -0.6387],
        [-0.5601, -0.6255, -0.2369,  ...,  1.1475,  0.4424,  0.9907],
        ...,
        [ 0.7690, -0.2961, -0.8652,  ...,  0.4819,  1.7148, -1.1602],
        [-0.4614, -0.7314,  0.7793,  ...,  0.3335, -0.9438,  0.6606],
        [-0.6694, -0.6616, -0.7578,  ...,  0.4729,  0.1444,  1.1924]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([90, 43, 81, 77, 53, 97, 16, 55, 77, 90, 86, 72, 87, 62, 98, 81,  9, 15,
        17, 65, 76, 74,  8, 59, 77, 75, 15, 97

tensor([[-0.3877, -0.9985, -0.2206,  ...,  0.5249,  0.8315,  0.7070],
        [-0.2194,  0.5269,  2.2227,  ...,  0.2502, -0.3799,  0.4050],
        [-0.0821, -1.4365, -0.3550,  ...,  0.5444,  1.7871,  0.8027],
        ...,
        [-0.0748, -1.0645, -0.7261,  ...,  0.6333,  0.2842, -0.0943],
        [ 0.1562, -0.5781, -0.3542,  ...,  0.4597,  0.5015, -0.7246],
        [ 0.4946, -0.5425,  0.0231,  ...,  0.7212,  1.8838, -0.8774]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 54,   2,  84,  74,  33,   5,   8,  76,  63,  49,  60,  49,  54,  49,
         54,  82,  75,  73,  41, 100,   6,  88,  82,  26,  27,  82,  40,  12,
          6,  56,  84,  56,  30,  45,  90,  38,  89,  43,  11,   9,  74, 100,
         59,   7,  38,  76,  30,  98,  36,  76], device='cuda:0')
tensor([[ 1.1455, -0.5142, -0.2373,  ...,  0.9570,  2.3203, -1.3711],
        [-0.4993, -0.6279,  0.6216,  ..., -0.0129,  0.7100, -0.5400],
        [-0.1683,  0.0566, -0.4739,  ..., -0.0483,  0.98

tensor([[-0.3191, -0.7217, -0.6812,  ...,  0.1543,  0.9541, -0.4526],
        [-0.6655,  0.8208, -0.8271,  ...,  1.3623,  0.3435, -0.3684],
        [-0.0812, -0.5034, -0.8306,  ...,  0.1561,  3.3594, -0.5474],
        ...,
        [ 0.5029, -0.0857,  1.0342,  ..., -0.1442,  0.3499, -0.1071],
        [ 0.2415, -1.1260, -0.9082,  ...,  0.9131,  1.2129, -0.7935],
        [-0.2131,  1.7070, -0.3059,  ...,  0.1554, -0.0623, -0.4453]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 49,  96, 100,  39,  98, 101,  51,  51,  97,  80,  25,  49,  41,   5,
         45,  41,  88,  11,  11,  56,  77,  91,  58,  89,  87,  90,  10,  74,
         33,  53,  41,  92,  20,  80,  32,  79,  31,  28,  26,   9,  94,  54,
         19,  47,   4,  74,  74,  30,  49,  34], device='cuda:0')
tensor([[ 0.7734,  0.2864, -0.4414,  ...,  0.4363,  2.1914, -0.8721],
        [-0.3396,  0.3174, -0.0880,  ..., -0.3142,  1.0186,  0.3831],
        [-0.1192,  0.5854,  0.6167,  ..., -0.3965, -0.15

       grad_fn=<AddmmBackward>) tensor([ 72,  59, 101,  49,  78,  93,  75,  68,  54,  85,  17,  62,  75,  59,
         55,  26,  57,  90,  85,  57,  60,  74,  40,  59,  87,  97,  96,  89,
         79,  64,  48,  43,  99,  49,  97,   6,  55,  38,  59,  57,  90,  35,
         43,  10,  49,  82,  24,  24,  52,  49], device='cuda:0')
tensor([[-0.1265,  1.6807,  0.3835,  ...,  0.6763, -0.3230, -0.6523],
        [-0.4702, -0.6074,  0.2096,  ...,  0.1086,  0.7334, -0.3606],
        [-0.5039,  0.8047,  0.4790,  ..., -0.0426,  0.5747,  0.6558],
        ...,
        [-0.6235, -1.0195, -0.0702,  ...,  0.3137, -0.1489,  0.9683],
        [ 0.6484, -0.3762,  0.2937,  ...,  0.3271,  1.0020, -1.1689],
        [-0.5938, -1.1914, -0.7524,  ...,  1.3320,  0.4785,  1.3555]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([34, 43, 29, 14, 23, 11, 48, 49, 30, 75, 82, 98, 82, 57, 90,  9, 80, 33,
        81, 66, 33, 71, 21, 49, 91,  3, 79, 26, 86, 23, 89, 71, 98, 42, 26, 84,
    

tensor([[-0.1371, -0.1105,  0.1920,  ..., -0.1105, -0.0479, -0.4463],
        [-0.3379, -1.0137, -0.6050,  ...,  1.3867,  0.7192,  0.5645],
        [ 0.0173, -0.9556, -1.0869,  ...,  0.4795, -0.0671, -0.2681],
        ...,
        [ 0.1042,  0.0525, -0.0468,  ...,  1.7031, -0.3796, -0.0886],
        [-0.5234, -1.0693,  0.5127,  ...,  0.5195,  0.5952,  1.2783],
        [ 0.2832, -0.4424, -0.6182,  ...,  0.4822,  1.9570, -0.6406]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 66,  97,  85,   4,  90,  17,  49,  99,  19,  73,  71, 100,  38,   1,
         82,  96,   8,  49,  48,  51,  57,   6,   8,  18,   1,  91,  83,  80,
         73,  77,  43,  12,  85,  61,  21,  10,  76,  54,  35,   9,  67,  46,
         14,  92,  46,  90,  18,  37,  40,  67], device='cuda:0')
tensor([[-0.0632, -1.4014, -0.5254,  ...,  1.2588,  1.5010,  0.6660],
        [-0.0653, -1.0439, -0.5493,  ...,  1.0830,  0.0071, -0.4304],
        [-0.5674, -1.2295, -0.1517,  ...,  1.3711,  0.46

tensor([[-1.7822e-01, -9.2480e-01, -5.5957e-01,  ...,  3.8989e-01,
          1.8115e+00,  1.4697e-01],
        [-4.4727e-01, -4.9048e-01, -1.7227e-02,  ..., -1.2636e-05,
         -4.8071e-01,  1.0527e+00],
        [ 4.3726e-01, -4.5013e-02, -5.2637e-01,  ...,  6.5479e-01,
          1.1221e+00, -3.9648e-01],
        ...,
        [-1.0284e-01, -5.4932e-01, -3.7549e-01,  ...,  4.8364e-01,
          9.6436e-03, -6.3477e-01],
        [ 6.0400e-01, -1.0254e-01,  2.6535e-02,  ...,  8.6328e-01,
         -6.2744e-01, -1.2842e+00],
        [-1.2164e-01, -7.7783e-01, -6.6833e-02,  ...,  1.6826e+00,
          1.1650e+00, -1.1774e-01]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([97, 73, 75, 91,  0, 41, 43, 48,  5, 34, 78, 81, 26, 56, 92, 64, 15, 15,
        93, 82, 64, 95, 82, 71, 37, 24, 66, 31, 43, 59, 38, 54, 32, 51, 77, 45,
        12, 25, 59, 13, 49,  1,  9,  2, 16, 12, 19, 51, 61, 90],
       device='cuda:0')
tensor([[ 0.3289,  2.0137,  0.3633,  ...,  0.0509

tensor([[ 0.0969, -1.2041, -0.2773,  ...,  0.3098,  1.3867,  0.7412],
        [ 0.4478, -0.1453,  0.4360,  ...,  1.0674,  1.5293, -0.4167],
        [-0.0915,  1.5527,  0.5518,  ..., -0.2351,  0.1727,  0.5112],
        ...,
        [ 0.3032, -0.7720, -0.5356,  ...,  0.6411,  2.2305,  0.1466],
        [-0.0447, -0.4570, -1.0830,  ...,  0.8770,  1.1240, -0.3574],
        [ 0.4453, -1.3691, -0.2433,  ...,  1.7617,  1.3291, -0.5254]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([72, 49, 29, 87, 32, 48, 71, 31,  6, 81, 74, 56, 43, 48, 74, 48, 97, 51,
        82, 73, 32,  6, 55, 14, 81, 66, 97, 60, 90, 54, 76, 57, 12, 37, 22, 45,
        40, 59, 97, 19, 82, 54, 49, 42, 37, 67, 43, 72, 75, 90],
       device='cuda:0')
tensor([[ 0.3032,  1.5156,  0.4680,  ..., -0.4067,  0.4229, -0.7207],
        [-0.5020,  0.2534,  2.7500,  ..., -0.3999, -1.1328,  0.7793],
        [-0.4192, -0.4424,  0.3235,  ...,  0.3142,  0.3081, -0.9287],
        ...,
        [-1.0195, -1.210

tensor([[ 0.1907, -0.4260, -0.6250,  ...,  0.2448, -0.2285, -0.0151],
        [-0.6221, -0.6084,  0.1224,  ...,  0.0211, -0.6060,  1.0654],
        [-0.2561,  0.6167,  0.1074,  ...,  0.8896, -0.6475,  0.2295],
        ...,
        [-0.5034,  0.5015,  1.5488,  ..., -0.5601, -0.7305,  0.7012],
        [-0.2102, -1.3477,  0.1887,  ...,  1.0137,  0.0856,  1.5303],
        [-0.1393,  0.2986, -0.4856,  ...,  0.1124,  1.6660, -0.8130]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 89,  73,  18,  50,  82,  64,  93,  20,  62,  33,  46,  43,  37,  73,
         89,  29,  86,  38,  49,  69,  49,  13,  43,  90,   6,   5,  61,  75,
         13,   4,  85,  49, 101,  78,  90,  15,   4,  51,  49,  66,  21,  31,
         76,  80,  56,  34,  98,  38,  40,  91], device='cuda:0')
tensor([[ 0.0858, -0.1309, -0.1395,  ...,  1.1504,  0.1647, -0.0764],
        [-0.6245,  0.8564,  1.7754,  ...,  0.1595, -0.0449,  0.2261],
        [-0.3474,  3.4082,  0.1521,  ...,  0.1619,  0.34

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 82,  66,  37,  77,  78,  54,  20,  73,  90,  49,  91,  18,  33,  54,
         43,  39,  61,  32,  38,  80,  60,  96,  72,  24,  19,  82, 101,  97,
          3,  16,  77,  14,  74,  71,  54,  83,   6,  60,  89,  83,  23,  83,
         57,  78,  81,   6,  44,  31,  40,  22], device='cuda:0')
tensor([[-0.4155,  0.5488, -0.4771,  ...,  0.0323,  1.6582, -1.0586],
        [-0.5679, -0.3479,  0.9756,  ...,  0.5020, -1.1670,  0.3049],
        [ 0.0487, -0.0414,  0.7339,  ..., -0.1722,  0.3384,  0.2252],
        ...,
        [ 0.0435, -1.4805, -0.6567,  ...,  0.9927,  0.8896,  1.1865],
        [-0.7856,  0.1855,  0.3225,  ...,  0.2274,  0.6162,  0.3423],
        [-0.0208,  0.0618,  0.0712,  ...,  1.1299, -0.3765, -1.0869]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([91, 59, 26, 82,  5, 31, 75,  1, 74, 46, 89, 90, 51, 48, 57, 41, 46, 51,
        98, 54, 49, 84, 81, 70, 27, 87, 77, 90

tensor([[-0.3140, -0.7905, -0.6240,  ..., -0.1105,  1.3135,  0.9844],
        [-0.6533, -0.8672, -0.2927,  ...,  0.8340,  1.0693,  0.1332],
        [-0.1094, -0.9321, -0.3486,  ...,  0.4067,  1.4980, -0.9824],
        ...,
        [ 0.1273,  2.9941,  1.0098,  ...,  0.4326, -0.0859, -1.2217],
        [-0.2208, -1.5859, -0.1008,  ...,  1.7461,  0.7021,  0.5654],
        [-0.5581, -0.6479,  0.5532,  ...,  0.1653, -0.1715,  0.3091]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 32,  92,   5,  57,  80,   5,  43,  86,   9,  84,  96,  78,  53,  64,
         22,  57,  18,  28,  71,  73,  77,  15,  86,  99,  84,  18,  57,  82,
         51,  43,  97,  81,  53,   2, 101,   9,  15,  75,  59,  18,  24,  54,
         88,  51,  80,  89,  30,   8,  81,  74], device='cuda:0')
tensor([[ 0.3687, -0.3159, -0.3057,  ...,  1.5713,  1.0713, -1.2520],
        [ 0.5396, -0.9321,  0.0973,  ...,  0.2646,  1.1592, -0.9380],
        [-0.0974, -0.4592, -0.8188,  ...,  0.2664,  2.28

tensor([[-0.7295,  0.8638,  1.3086,  ..., -0.4575,  0.2452, -0.2607],
        [ 0.0794, -1.2646,  0.2189,  ...,  0.7124,  0.1005,  0.4133],
        [-0.7158, -0.1555, -0.1285,  ..., -0.2114, -0.0873,  1.3867],
        ...,
        [ 0.2357, -0.9771,  0.2229,  ...,  1.4961,  1.5410, -0.7251],
        [ 0.3970, -0.6079, -0.8115,  ...,  0.5806,  1.0586, -0.5552],
        [-0.3066,  1.1748,  0.6597,  ...,  0.3813,  0.3442,  1.4697]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([62, 12, 73, 73, 90, 60, 15, 95, 10, 57, 43, 74, 55, 60,  3, 38, 56, 82,
         5, 59,  9, 95, 40, 40, 28, 76, 82, 76, 49, 84, 80, 90, 24, 32, 14, 51,
        33, 49, 45, 76, 83, 90, 92, 72, 78, 20, 15, 90, 36, 77],
       device='cuda:0')
tensor([[-0.3145, -0.8315, -0.7783,  ...,  0.5962,  0.5132,  0.1086],
        [-0.6851,  0.1165,  1.6641,  ..., -0.2222,  0.2654,  0.5811],
        [-0.4902,  0.2690,  0.8652,  ...,  0.3044, -0.2817,  0.1545],
        ...,
        [-1.2334,  0.928

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([81, 31], device='cuda:0')
Training: Epoch 28 || Loss:   1.547 || Accuracy:  81.97%
tensor([[-0.5415,  1.6426,  1.2578,  ...,  0.3640, -0.0963,  0.8330],
        [-0.2939, -1.1602, -0.3623,  ...,  1.1055,  1.1611,  0.4878],
        [ 0.6660, -0.3162,  0.1277,  ...,  1.0791,  0.1700, -1.3135],
        ...,
        [ 0.5747,  1.1943,  0.0597,  ..., -0.4294,  0.9932, -0.7959],
        [ 0.7710, -1.0098,  0.0301,  ...,  0.6772,  1.3525, -0.4216],
        [-0.2366, -0.8320, -0.0086,  ...,  0.8447,  0.8042,  0.1810]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.2069,  0.7295,  0.5771,

tensor([[ 0.0473,  1.6055,  1.1865,  ..., -0.0349,  0.0204, -0.4097],
        [ 0.5781, -0.5801, -0.8892,  ...,  0.8218,  1.3486, -0.5151],
        [ 0.7437, -0.2231, -0.2666,  ...,  1.2158, -0.2617, -1.3291],
        ...,
        [-0.3230, -0.3213, -0.3093,  ...,  0.3477, -0.0798,  0.9106],
        [ 0.2028,  1.0508, -0.2778,  ...,  0.2720,  0.0294, -0.0874],
        [-0.5303,  4.0469,  0.4900,  ...,  0.3037, -0.1495,  0.2318]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 48,  49,  61,  50,  35,  15,  46,   5,  44,  51,  90,  71,  27,  39,
         37,  18,  60,  23,  37,  71,  83,  81,  34,  35,  82,  25,  99,  97,
         85,  89,  51,  82,   7,  48,  74,  19,  64, 101,  24,  86,  14,  34,
         93,  72,  72,  86,  88,  78,  66,   1], device='cuda:0')
tensor([[-0.6479,  1.4570,  0.8687,  ..., -0.4165, -0.8027,  0.4597],
        [-0.1454,  1.5771,  0.3706,  ...,  0.9946,  0.7236,  0.4011],
        [ 0.9634, -0.8359, -0.3064,  ...,  0.6250,  1.48

tensor([[ 0.1489,  0.5503, -0.7725,  ...,  0.3450,  0.3989, -0.8164],
        [-0.4011, -0.5659, -0.6685,  ...,  0.0035,  0.6582,  0.0106],
        [ 0.3794, -0.0382,  0.0787,  ...,  0.6938,  1.3047,  0.5029],
        ...,
        [-0.2632, -0.8726, -1.0830,  ...,  1.4814,  1.9375,  0.2118],
        [-0.2201, -0.8359,  0.6865,  ...,  0.8330,  0.6641,  0.1667],
        [-0.0598, -0.0696,  1.4316,  ...,  0.4819, -0.2374,  0.8979]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([96, 27, 83, 90,  5, 98, 38, 33, 33, 64, 23, 63, 64, 84, 38, 73,  5, 38,
        72, 74,  6, 72, 98, 50, 30, 71, 81, 77, 60, 92, 57, 26, 43, 63, 65, 68,
        92, 73, 32,  4, 96, 73, 73, 21, 75, 46, 90, 75, 92, 38],
       device='cuda:0')
tensor([[-0.5132, -0.3975,  0.6167,  ...,  0.2778,  0.4175,  0.8159],
        [-0.4961,  0.5107,  1.7354,  ..., -0.3713, -0.3262,  0.2847],
        [ 0.6162,  0.3267, -0.4341,  ...,  0.9595,  1.8682, -1.3770],
        ...,
        [ 0.7217, -0.514

tensor([[-0.5449,  1.3252, -0.2961,  ...,  0.6582,  0.0743, -0.2532],
        [-0.4666, -0.6343,  0.8667,  ...,  0.1079,  0.1088, -0.4429],
        [ 0.1873, -0.0349, -0.1741,  ...,  0.8564,  2.0234, -0.3381],
        ...,
        [ 0.1860,  0.5830, -0.0486,  ...,  1.5146, -0.1816, -0.2086],
        [-0.6694, -1.0195, -0.8809,  ...,  0.6597,  0.0521,  0.2930],
        [-0.4917, -0.6055, -0.1204,  ...,  0.4561,  0.0274,  0.8271]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  43,  49,  17,  22,  21, 101,  78,  85,  74,  38,  98,  54,  77,
         59,  15,  41,  31,  78,  20,  73,  38,  11,  92,  49,  40,  73,  19,
         49,  79,  85, 101,  47,  88,  37,  90,  53,  41,  66,  90,  89,  92,
         49,  99,  90,  88,  25,  96,  15,  78], device='cuda:0')
tensor([[-0.1493,  2.1133,  0.7964,  ...,  0.1447,  0.3171,  0.0808],
        [ 0.0659, -1.0225, -0.3726,  ...,  1.2910,  0.5322,  0.0561],
        [-0.1007,  1.1328,  0.6016,  ...,  0.7222,  0.23

tensor([[-0.4937,  0.3855,  1.4951,  ..., -0.2825,  0.2047, -0.1788],
        [ 0.4712, -0.3987, -0.5933,  ...,  2.2832, -0.7695, -0.5249],
        [-0.2761, -0.0674, -0.9780,  ..., -0.2274,  0.6084,  0.1676],
        ...,
        [-0.1758,  0.6514,  1.7148,  ..., -0.1549, -0.6992,  0.0300],
        [-0.0063,  0.4490, -0.4983,  ...,  0.2260,  1.6465, -1.6113],
        [ 0.0817,  0.4187,  1.1836,  ...,  0.4910, -0.7852, -0.1560]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 44,  99,  33,  52,  56,  82,  72,  65,  93,  78,   5,  27,  11,  80,
         33,  50,   6,  99,  34,  56,  54,  10,  14,  30,  68,  68,  85,  82,
          2,  47, 100,  59,  28,   7,  41,  90,   3,   9,  54,  53,  96,  90,
         99,  41,  35,   0, 100,   2,  91,  38], device='cuda:0')
tensor([[-0.1964, -0.7207,  0.5269,  ...,  0.5273, -0.7114,  0.9546],
        [ 0.0659,  1.2432,  1.3936,  ...,  0.7710,  0.3872, -0.8574],
        [-0.5308, -0.7051,  0.1260,  ...,  0.3621, -0.57

tensor([[ 0.5723,  0.0180, -0.4224,  ...,  0.1385,  1.7734, -0.2264],
        [-1.0566, -0.7324,  0.0780,  ..., -0.2338,  0.9795,  2.9844],
        [ 0.4551, -0.3293, -0.2014,  ...,  1.0322, -0.4268, -1.3369],
        ...,
        [-0.7241, -0.9053,  0.0158,  ...,  0.4824,  0.5586,  0.0838],
        [ 0.2781, -0.2976, -0.1562,  ..., -0.1462,  1.7832, -0.8042],
        [-0.6548, -0.3091, -0.0059,  ..., -0.3125, -1.0889,  1.0840]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 13,  80,  61, 100,  40,  33,  85,  49,  93,  75,  93,  62,  17,  57,
         78,  78,  77,  74,  24,  80,  33,  25,  29,  12,  78,   1,  85,  55,
         59,  90,  57,  51,  70,   0,  96,  15,  25,  18,  51,  47,  89,  82,
         38,  28,  90,  47,  53,  43,  28,  73], device='cuda:0')
tensor([[-0.5137, -0.4819,  0.1306,  ..., -0.0172, -0.1289,  0.0344],
        [-0.2749, -0.8018,  0.1760,  ...,  0.7607, -0.1660,  0.4707],
        [-0.3582,  1.4902,  1.2842,  ..., -0.1799,  0.17

tensor([[ 0.0963,  0.3596, -0.8223,  ...,  2.4648, -0.0875, -0.9966],
        [ 0.6489, -0.8325, -0.7759,  ...,  0.7080,  2.7773, -0.8276],
        [-0.4824, -1.0557, -0.6274,  ...,  0.3772,  0.2039,  1.2988],
        ...,
        [ 0.2026, -0.6865, -0.4387,  ...,  0.8682,  1.8438, -0.1390],
        [-1.0684, -0.3186, -0.2491,  ...,  0.3904,  0.4436,  2.3691],
        [ 0.1575, -0.9302, -0.1175,  ...,  0.7964,  0.6558,  0.5669]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 99,  53,  69,  81,  61,  18,  78,  33,  39,  43,  65,  50,  90,  86,
         55,  15,  83,  32,  14,  77,  33,  82,  62,  24,  25,  84,  51,   2,
         56,  73,  89, 101,  14,  97,   9,  26,  49,  48,  85,  97,  49,  89,
         38,  77,  43,  64,  24,  72,  19,  72], device='cuda:0')
tensor([[-0.3494, -0.0055, -0.3770,  ...,  0.8364,  0.5859, -1.2158],
        [ 0.4800, -0.5547,  0.7319,  ...,  0.3237,  0.7305, -0.8955],
        [-0.0227, -1.1992,  0.0558,  ...,  1.1191,  1.84

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([79, 60, 84, 51, 43, 92, 80, 13, 94, 96, 54, 73, 89, 54, 48, 38, 66, 80,
        77, 67, 94, 60, 36, 15, 49, 49, 10, 51, 90, 52, 85, 45, 84, 79, 83, 58,
        59, 39, 40, 49, 58, 70, 24, 23, 51, 97, 47, 14, 77, 77],
       device='cuda:0')
tensor([[-0.0737, -0.3242,  0.3047,  ...,  1.0391, -0.2891, -0.3364],
        [-0.1907, -1.4180, -0.5688,  ...,  1.5869,  1.1484,  1.3877],
        [-0.5283, -0.5840, -0.0880,  ..., -0.2861, -0.4070,  1.0781],
        ...,
        [ 0.3027, -0.6348, -0.0814,  ...,  0.8672,  1.8994, -0.5225],
        [-0.0802, -0.2612,  0.3186,  ...,  0.3020, -0.2842, -0.3682],
        [-0.7085, -0.4946, -0.2849,  ...,  0.1444, -0.5464,  1.4170]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([58, 98, 73, 43, 96, 92, 44, 42, 54, 45, 59, 41, 63, 49, 90, 29, 36, 47,
        43, 96, 77, 74, 82, 49, 85, 50, 38, 98, 49, 78, 29, 65, 49, 41, 44,  0,
        58, 85, 5

tensor([[ 0.4336, -1.0176, -0.8555,  ...,  0.0365,  1.4697,  0.0588],
        [-0.0210,  0.2363, -0.1301,  ...,  0.2296,  0.6743,  0.1499],
        [ 0.3499,  0.5259, -0.3467,  ...,  0.0700,  3.2754, -0.3232],
        ...,
        [ 0.1350,  1.1943, -0.1064,  ..., -0.1415,  1.1152, -0.9146],
        [-0.1277, -0.1776, -0.3870,  ...,  0.3298,  0.2054,  1.0820],
        [ 0.2373, -0.3171, -0.1641,  ...,  0.4231,  0.6152,  0.8032]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 72,  23,  67,  83,  47,  91,  19,  55,  78,  77,   7,  78,  64,  76,
         73,  55,   8,  80,  17,  82,  30,  92,  43,  78,  24,  44,  84,  44,
        100,  22,  57,  87,  76,  70,  71,  49,  17,  93,   6,  90,  71,  93,
         15,  86,  18,  96,  56,  17,  35,  33], device='cuda:0')
tensor([[ 0.6890, -0.6104, -0.0477,  ...,  0.1678,  0.2236,  0.4136],
        [ 0.0951,  0.0038, -0.1126,  ..., -0.5659,  0.6606,  0.1422],
        [ 0.3306, -0.4160, -1.0459,  ...,  1.3799,  2.30

tensor([[-0.5688,  0.4119,  0.3354,  ..., -0.1228,  0.5825,  2.0449],
        [-0.1531, -1.4111, -0.4829,  ...,  1.4473,  0.9897,  0.9072],
        [-0.4368, -0.3435, -0.9487,  ...,  0.2451, -0.5630,  0.2852],
        ...,
        [ 0.6411, -0.2299, -0.4360,  ...,  0.8394, -0.0219, -0.6074],
        [ 0.0189, -0.5205, -1.0664,  ...,  0.9834,  1.2998,  0.6680],
        [-0.4502,  0.8965,  1.9619,  ..., -0.2426,  0.1671,  0.0244]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([95, 81, 89, 66, 43, 20, 43, 85, 68, 63,  3, 32, 49, 37, 77, 97, 82, 53,
        82, 83, 71, 62, 90, 88, 90, 78, 78, 60, 32, 17, 90, 10, 51, 82, 96, 78,
        82, 93, 73, 51, 47, 65, 82, 63, 64, 41, 80, 68, 75, 52],
       device='cuda:0')
tensor([[ 0.3311, -0.5254,  0.0495,  ...,  1.3164,  0.4663, -0.6357],
        [ 0.3306, -0.3450, -0.3806,  ...,  1.5449,  0.1726, -0.8135],
        [ 0.0067,  0.9092, -0.5542,  ...,  0.0528,  0.9370,  0.2717],
        ...,
        [-0.3345, -0.362

tensor([[ 0.1126,  1.0098,  1.3770,  ...,  0.4202, -0.2893, -0.9609],
        [ 0.0219, -0.7363, -0.1384,  ...,  0.6675,  1.7588, -0.3298],
        [ 0.3701, -1.3389,  0.1274,  ...,  1.1172,  2.2305,  0.2603],
        ...,
        [ 0.8193, -0.1503, -0.5791,  ..., -0.0264,  3.7773, -1.1465],
        [ 0.5229, -0.9180, -0.3506,  ...,  0.5293,  0.2615, -1.0928],
        [ 0.0260, -0.1192, -0.8003,  ..., -0.1041,  0.9258, -0.7339]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([  6,  90,  84,  31,   1,  73,  82,  65,   6,  88,  31,  97,  81,  50,
         78,  84,  56,  74,  90,  23,  83,  89,  27,  83,  82,  24,  12,  84,
         56,  10,  23, 100,  96,   9,  70,  65,  72,  19,  40,  69,  49,  49,
         67,  59,  82,  30,  65, 100,  50,  28], device='cuda:0')
tensor([[ 0.4373, -0.0900, -0.0936,  ...,  0.2086,  0.3430,  0.2600],
        [ 0.0817,  1.6914,  1.3096,  ..., -0.0124, -0.3997, -0.3799],
        [ 0.4583, -0.6201,  0.4844,  ...,  0.6831, -0.59

tensor([[-0.2007,  1.1895,  0.3550,  ...,  0.6890,  0.9541,  1.1445],
        [ 0.6201, -0.5771, -0.2998,  ...,  0.1696,  1.5010, -0.7939],
        [-0.7227, -1.0859, -0.7065,  ..., -0.1735,  0.2910,  0.7476],
        ...,
        [ 0.3835,  0.1163, -0.0141,  ...,  0.9404,  0.5767, -1.3574],
        [-0.2932, -1.0234, -0.9307,  ...,  1.2119,  2.3047, -0.7935],
        [-0.5054,  0.5967,  1.3545,  ...,  0.4719, -0.8350,  0.1638]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([77, 50, 69, 76, 52, 33, 75, 46, 43, 21,  1, 79, 54, 94, 49, 33, 93, 86,
        23, 60, 73, 85, 74, 33, 76,  6, 76, 78, 84, 65, 73, 96, 33, 21, 10, 39,
        77, 91, 48, 74, 54, 39, 84, 77, 96, 96, 43, 70, 97, 57],
       device='cuda:0')
tensor([[ 0.2267,  0.3411, -0.7979,  ...,  0.8179, -0.5845, -1.0020],
        [ 0.3103, -0.4182, -0.3735,  ...,  0.2251,  0.2396, -0.5791],
        [-0.3479, -0.6362,  0.0643,  ...,  0.5479, -0.7153,  0.4949],
        ...,
        [-0.1455, -0.127

tensor([[ 0.1084,  0.1813, -0.1725,  ...,  1.0449, -0.1328, -0.7939],
        [-0.4277,  1.5928,  1.2930,  ...,  0.3408, -0.0911,  0.7017],
        [-0.5059, -0.1929,  1.2676,  ...,  0.2744, -1.0264,  0.5415],
        ...,
        [-0.2981, -1.2490, -0.6406,  ...,  0.2534, -0.1081,  0.2957],
        [-0.8052,  0.4795,  1.6641,  ..., -0.1425, -0.3669,  0.5552],
        [-0.5449, -0.7358, -0.1840,  ...,  1.0225,  0.4058,  0.9561]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  24,  59,  77, 101,   8,  80,  82,  80,  83,  89,  11,  46,  76,
         96,  77,  97,  78,  68,  77,  54,  55,  76,  80,  80,  87,  82,  76,
          0,  68,  34,  47,  30,  57,  13,  98,  90,  20,  49, 100,  75,  41,
         88,  70,  77,  17,  39,  15,  52,  41], device='cuda:0')
tensor([[-0.0837, -0.4456, -1.1494,  ...,  0.7822,  0.7183,  0.4880],
        [ 0.6055, -0.4385, -0.6230,  ...,  0.5825,  1.8408, -1.1367],
        [-0.0520,  2.7910,  0.9229,  ...,  0.2585, -0.05

0it [00:00, ?it/s]

Found best model
Training: Epoch 30 || Loss:   1.466 || Accuracy:  83.70%
tensor([[-0.5020,  1.7549,  1.3926,  ...,  0.4224, -0.1054,  0.8940],
        [-0.2362, -1.1797, -0.3323,  ...,  1.2129,  1.2324,  0.5283],
        [ 0.7920, -0.2856,  0.1781,  ...,  1.1807,  0.1981, -1.3584],
        ...,
        [ 0.6929,  1.2881,  0.1268,  ..., -0.4294,  1.0615, -0.8081],
        [ 0.8979, -1.0234,  0.0862,  ...,  0.7480,  1.4326, -0.4175],
        [-0.1692, -0.8457,  0.0370,  ...,  0.9307,  0.8501,  0.2057]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 94,  84,  45,  81,  28,  85,  57,  77,  98,  11,  27,  55,  14,  99,
         84,  13,  11,  67,  89,  26,  52,   1, 101,  91,  41,  33,  27,  77,
         77,  84,  98,  23,  78,  10,  84,  35,  82,  84,  80,  51,  93,  41,
         54,  96,  82,  52,  53,  17,  49,  92], device='cuda:0')
tensor([[-0.1395,  0.8047,  0.6699,  ...,  0.1576,  0.9971,  1.1523],
        [ 0.7407, -0.3406,  0.5405,  ...,  1.1729,  

tensor([[-0.4138, -0.2710, -0.1287,  ..., -0.0804, -0.6143,  0.5361],
        [-0.2351,  0.1842,  0.0985,  ...,  1.1416, -1.0820,  0.0260],
        [ 0.0114, -0.0749, -0.5562,  ...,  0.3748, -0.0910,  0.5723],
        ...,
        [ 0.5889,  0.0871, -0.1965,  ...,  0.6772,  1.1602, -1.3877],
        [-0.4058,  0.2827,  1.4082,  ...,  0.7642, -0.9204,  0.1483],
        [-0.0198, -1.1338,  0.6372,  ...,  0.3452,  0.8799,  0.4431]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([78, 18, 35, 99, 96, 28, 82, 53, 40, 77, 49, 43, 64,  7, 23, 64, 86, 84,
        59, 30, 22, 53, 23, 83, 83, 74,  6, 57, 59, 50,  8, 96, 75, 43, 89, 84,
        29, 43, 97,  7, 80, 77, 59, 96, 65, 15, 26, 36, 57, 40],
       device='cuda:0')
tensor([[ 0.0459, -0.1632, -0.1160,  ...,  0.8091,  0.3259, -0.6611],
        [ 0.8076, -0.3992, -0.2068,  ...,  0.2915,  2.0742, -1.0605],
        [-0.5894,  0.1522,  2.0215,  ..., -0.0704, -0.8101,  1.3438],
        ...,
        [-0.1030, -0.952

tensor([[ 0.4844, -1.0234, -0.8545,  ...,  0.0523,  1.5137,  0.0698],
        [ 0.0190,  0.2598, -0.1136,  ...,  0.2422,  0.6938,  0.1631],
        [ 0.3970,  0.5503, -0.3386,  ...,  0.0851,  3.3594, -0.3208],
        ...,
        [ 0.1781,  1.2383, -0.0796,  ..., -0.1367,  1.1494, -0.9287],
        [-0.0954, -0.1595, -0.3645,  ...,  0.3540,  0.2090,  1.1104],
        [ 0.2808, -0.3093, -0.1356,  ...,  0.4465,  0.6338,  0.8301]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 72,  23,  67,  83,  47,  91,  19,  55,  78,  77,   7,  78,  64,  76,
         73,  55,   8,  80,  17,  82,  30,  92,  43,  78,  24,  44,  84,  44,
        100,  22,  57,  87,  76,  70,  71,  49,  17,  93,   6,  90,  71,  93,
         15,  86,  18,  96,  56,  17,  35,  33], device='cuda:0')
tensor([[ 0.7383, -0.6104, -0.0239,  ...,  0.1832,  0.2410,  0.4346],
        [ 0.1345,  0.0262, -0.0862,  ..., -0.5659,  0.6792,  0.1528],
        [ 0.3811, -0.3989, -1.0508,  ...,  1.4316,  2.36

tensor([[-0.5015,  1.4414, -0.2532,  ...,  0.7178,  0.0737, -0.2467],
        [-0.4116, -0.6343,  0.9741,  ...,  0.1561,  0.1201, -0.4395],
        [ 0.2864,  0.0106, -0.1296,  ...,  0.9346,  2.1367, -0.3384],
        ...,
        [ 0.2832,  0.6519,  0.0060,  ...,  1.6299, -0.1875, -0.2140],
        [-0.6299, -1.0127, -0.8662,  ...,  0.7241,  0.0537,  0.3167],
        [-0.4490, -0.5693, -0.0614,  ...,  0.5146,  0.0302,  0.8809]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 96,  43,  49,  17,  22,  21, 101,  78,  85,  74,  38,  98,  54,  77,
         59,  15,  41,  31,  78,  20,  73,  38,  11,  92,  49,  40,  73,  19,
         49,  79,  85, 101,  47,  88,  37,  90,  53,  41,  66,  90,  89,  92,
         49,  99,  90,  88,  25,  96,  15,  78], device='cuda:0')
tensor([[-0.0706,  2.2520,  0.8950,  ...,  0.1709,  0.3289,  0.0965],
        [ 0.1489, -1.0254, -0.3450,  ...,  1.3975,  0.5781,  0.0825],
        [-0.0236,  1.2383,  0.6870,  ...,  0.7915,  0.25

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([76,  6,  3, 78,  6, 91, 22, 81, 98, 57, 38, 24, 35, 24, 85, 37, 41, 79,
        17, 51, 77, 49, 86,  3, 37, 64, 43, 63, 41, 14, 34, 76, 21, 77, 29, 66,
        16, 43, 31, 94, 87, 81, 63, 77, 78, 27, 51, 21, 12, 35],
       device='cuda:0')
tensor([[-0.4878, -0.4653,  0.1455,  ...,  0.2737,  0.5474,  0.5723],
        [-0.4644,  1.6016,  1.1699,  ...,  0.4172, -0.7886,  0.1403],
        [-0.7949, -1.3008, -0.4380,  ...,  1.6221,  0.8213,  1.9160],
        ...,
        [-0.2776, -1.0518, -0.1530,  ...,  0.3154,  1.4062, -0.2421],
        [ 0.5059, -0.2277, -1.0439,  ...,  0.6309,  2.4141, -1.9512],
        [ 0.5688, -0.5596,  0.3674,  ...,  1.0908, -0.1550, -1.2871]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 74,   6,  81, 101,  18,  11,  49,  75,  84,  24,  40,  27,  23,  87,
         87,  27,  74,  74,  34,  85,  97,  43,  98,  38,  54,  90,  79,  33,
         40,  62,  33

tensor([[-0.4299,  0.3232,  1.6025,  ...,  0.0703, -0.1360,  0.5513],
        [ 0.3728, -0.2563, -0.2454,  ...,  2.8496,  0.5791, -0.7393],
        [-0.5146, -0.6943, -0.6172,  ..., -0.0126,  0.0195,  1.1426],
        ...,
        [ 0.1075, -0.9131,  0.4250,  ...,  0.2917, -1.1670,  0.7148],
        [ 0.1582,  1.8477,  1.5000,  ...,  0.0970, -0.0890, -0.7783],
        [ 0.6812,  0.0479,  0.1200,  ...,  0.7817,  2.3750, -1.3115]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 38,  99,  88,  28,  74,  83,  12,  81,  40,  77,  77,  76,  77,  91,
         75,  77,  23,  84,  40,  85,  50,  72, 101,  90,  38,  96,  32,  85,
         17,  49,  44,   1,  49,  49,  77,  28,  47,  46,  77,  50,  50,  40,
         81,  41,  70,  74,   8,  12,   6,  76], device='cuda:0')
tensor([[-0.1769,  1.2363,  0.3882,  ...,  0.7202,  0.9805,  1.1748],
        [ 0.6758, -0.5674, -0.2805,  ...,  0.1879,  1.5410, -0.8003],
        [-0.7075, -1.0908, -0.7031,  ..., -0.1635,  0.29

tensor([[ 0.1814,  0.4185, -0.8232,  ...,  2.6426, -0.0827, -1.0342],
        [ 0.7749, -0.8213, -0.7705,  ...,  0.7832,  2.9141, -0.8442],
        [-0.4392, -1.0693, -0.6196,  ...,  0.4189,  0.2183,  1.3711],
        ...,
        [ 0.3032, -0.6860, -0.4136,  ...,  0.9512,  1.9395, -0.1201],
        [-1.0664, -0.2776, -0.1995,  ...,  0.4478,  0.4436,  2.4902],
        [ 0.2452, -0.9399, -0.0814,  ...,  0.8721,  0.7056,  0.6079]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 99,  53,  69,  81,  61,  18,  78,  33,  39,  43,  65,  50,  90,  86,
         55,  15,  83,  32,  14,  77,  33,  82,  62,  24,  25,  84,  51,   2,
         56,  73,  89, 101,  14,  97,   9,  26,  49,  48,  85,  97,  49,  89,
         38,  77,  43,  64,  24,  72,  19,  72], device='cuda:0')
tensor([[-0.2629,  0.0403, -0.3567,  ...,  0.9253,  0.6123, -1.2539],
        [ 0.6040, -0.5322,  0.8218,  ...,  0.3801,  0.7769, -0.9209],
        [ 0.0568, -1.2051,  0.1011,  ...,  1.2236,  1.94

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([59, 70, 73, 38, 90, 82, 69, 11, 97, 76, 26, 61, 31, 89, 74, 96, 63, 97,
        63, 77, 10,  6, 45, 91, 94, 53, 57, 82, 77, 76, 89, 51, 73, 15, 42, 61,
        17, 87, 74, 89, 13, 59, 63, 43, 59, 90, 31, 43, 49, 77],
       device='cuda:0')
tensor([[ 0.5552, -0.8994, -0.5107,  ...,  0.8213,  0.7510, -0.5508],
        [-0.3245, -1.3076,  0.0531,  ...,  0.1974,  0.9810, -0.0823],
        [-0.5415, -1.2500, -0.3235,  ...,  0.8101,  1.3105,  1.2578],
        ...,
        [ 0.6548, -0.3418,  0.2549,  ...,  1.2393, -0.2852, -1.4746],
        [ 0.3821, -1.3076,  0.1185,  ...,  1.2539,  1.0410,  0.1323],
        [ 0.0123, -0.6782,  0.0536,  ...,  0.4766,  1.5879,  1.6963]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([49,  5,  3, 77, 92, 33, 69, 49, 19, 49, 27, 71, 89, 16, 31, 96, 12, 65,
        66, 40, 62, 80, 86, 23, 96, 45, 41, 40, 90, 77, 88, 31, 95, 50, 80, 97,
        46, 67, 6

tensor([[ 0.7773, -0.9287,  0.3638,  ...,  1.3281,  1.7100,  0.1512],
        [ 0.7612, -0.9214,  0.8174,  ...,  0.6538,  0.5142, -0.1576],
        [ 0.3745,  0.3936,  1.2256,  ...,  0.0123,  0.3462, -0.3513],
        ...,
        [ 0.3110, -0.8047,  0.3560,  ...,  1.1465,  0.0055, -0.2810],
        [-0.2253,  0.3979,  0.2625,  ...,  0.8691,  1.2197,  1.9561],
        [ 0.1666, -0.5469, -0.5654,  ...,  1.1143,  2.2832,  0.3384]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 84,  12,  65,   8,  18,  22,  92,  18,  22,  66,  25,  39,  44,  45,
          7,  65,  51,  51,   9, 100,  89,  74,  98,  11,  48,  78,  89,   8,
         75,  74,  75,  90,  96,  64,  72,   3,  82,  83,  74,  64,   6,  72,
         94,   3,  33,  77,  41,  37,  95,  75], device='cuda:0')
tensor([[ 0.1434, -0.2008,  0.2964,  ...,  0.5225,  1.1436,  0.7056],
        [-0.1232, -0.2993,  0.0999,  ...,  0.3672, -1.1016,  1.6875],
        [-0.1667,  1.3252,  1.7920,  ...,  0.1587,  0.40

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([95, 81, 89, 66, 43, 20, 43, 85, 68, 63,  3, 32, 49, 37, 77, 97, 82, 53,
        82, 83, 71, 62, 90, 88, 90, 78, 78, 60, 32, 17, 90, 10, 51, 82, 96, 78,
        82, 93, 73, 51, 47, 65, 82, 63, 64, 41, 80, 68, 75, 52],
       device='cuda:0')
tensor([[ 0.4353, -0.4985,  0.0942,  ...,  1.4180,  0.5093, -0.6338],
        [ 0.4265, -0.3167, -0.3467,  ...,  1.6641,  0.1865, -0.8320],
        [ 0.0907,  0.9907, -0.5337,  ...,  0.0840,  0.9907,  0.3081],
        ...,
        [-0.2742, -0.3413,  0.6377,  ...,  0.4314,  0.5259,  1.5654],
        [ 0.1699,  0.0274, -0.2961,  ...,  3.6270,  0.2468, -0.5322],
        [ 1.0732, -0.8452, -0.5044,  ...,  0.8579,  2.8027, -1.4209]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 51,  96,  94,  49,  20,  89,  55,  98,  78,  49,  32,  98,  64,  74,
         87,   8,  17,  77,  49,   5,  75,  24,  38,  89,  89,   2,  11,  76,
         42,  88,  64

tensor([[-0.3113, -0.3706, -0.4478,  ..., -0.0997, -0.9604,  1.7207],
        [ 0.1909, -0.2708, -0.3875,  ...,  0.2605, -1.0967, -0.3752],
        [ 0.6372, -0.9883, -0.2399,  ...,  0.3413,  2.7090, -0.2382],
        ...,
        [ 0.7021,  0.4534, -0.9502,  ...,  2.1035,  2.0215, -0.8872],
        [-0.0175, -0.6304, -0.3333,  ...,  2.0977,  0.7920, -0.3115],
        [-0.2152, -0.6660, -1.0518,  ...,  0.0868,  1.0039,  0.0257]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 93,  68,  28,   6,  24,  76,  65,  27,  24,   3,  85,  21,  49,  62,
         53,   3,  62, 100,  48,  78,  77,  36,  79,  51,  94,  27,  54,  51,
         53, 100,  41,  58,  98,  93,  58,  94,   5,  49,  76,  64,   3,  74,
         12,  52,   0,  69,  10,  13,  90,  55], device='cuda:0')
tensor([[ 0.7041, -1.1406, -0.6465,  ...,  0.7822,  3.7891, -0.2639],
        [-0.3713, -0.1685, -1.0898,  ..., -0.0940, -0.1183,  0.3242],
        [-0.2239, -0.9941, -0.5942,  ...,  0.4556, -0.12

tensor([[-0.3823,  0.1853,  1.4580,  ...,  0.6680, -1.1006, -0.0850],
        [-1.0029,  0.0369,  0.6538,  ...,  0.0319,  1.1455,  3.0977],
        [-0.5488, -0.6396,  0.0181,  ...,  1.0361, -0.1216,  1.3955],
        ...,
        [ 0.0609, -0.8179, -0.3044,  ...,  0.5186,  0.4827, -0.0597],
        [ 0.2849, -1.2559, -0.7827,  ..., -0.3701,  2.6172,  0.3313],
        [ 0.5107, -0.3372, -0.3176,  ...,  1.5947, -0.7251, -0.8843]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([57, 19,  3, 77, 37,  1, 34, 58, 74, 48, 64, 38, 74, 80,  6, 33,  3, 69,
        90, 74, 94, 97, 76, 89, 73, 37, 87,  7, 40, 18,  1, 90, 73, 17, 97,  8,
        27, 48,  9, 25, 76, 74, 48, 44, 74, 77, 23,  9, 28, 61],
       device='cuda:0')
tensor([[-0.7256, -0.0787, -0.0936,  ...,  1.2881, -0.2372,  1.0859],
        [-0.4304, -0.9087, -0.2162,  ...,  0.1704,  0.4783,  1.1455],
        [-0.4641,  0.0858, -0.3560,  ...,  0.2312,  1.2822,  1.1982],
        ...,
        [-0.1603, -1.298

tensor([[-0.0166, -0.4216, -1.1592,  ...,  0.8496,  0.7637,  0.5215],
        [ 0.7178, -0.4087, -0.6113,  ...,  0.6499,  1.9277, -1.1777],
        [ 0.0217,  2.9609,  1.0088,  ...,  0.3022, -0.0692, -0.2406],
        ...,
        [-0.0594,  0.5142,  0.1622,  ..., -0.2996,  0.6323,  2.2656],
        [-0.4153, -1.1367, -0.4788,  ..., -0.0923,  0.4175,  0.4255],
        [ 0.7031, -0.2620, -0.3755,  ...,  1.3447, -0.3706, -0.7573]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 75,  87,   8,  80,  56,  59,  72,  41,  80,  16,  85,  41,   5,  76,
         53,  71,  81, 101,  40,  64, 101,  43,  73,  54,  96,  11,  51,  11,
         64,  12,  56,  25,  11,   4,  63,  46,   6,  58,  24,  92,  77,  15,
         95,  90,  84,  66,  35,  95,  89,  68], device='cuda:0')
tensor([[-3.5840e-01, -8.0322e-01, -7.0752e-01,  ...,  5.3906e-01,
          1.7998e+00,  1.1094e+00],
        [ 7.3779e-01, -3.7793e-01,  2.4597e-01,  ...,  9.1125e-02,
          2.7695e+00, -3.9

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([22, 96, 15, 92, 67, 43, 81, 56, 78, 17, 64, 24, 24, 81, 84, 19, 27, 55,
        98, 55, 43, 26, 78,  7, 65, 44, 36, 72, 21, 73, 29, 82,  1, 85, 38,  0,
        14, 33, 74, 19, 12, 90, 18, 59, 77, 60, 94, 89, 49, 80],
       device='cuda:0')
tensor([[-0.2308, -0.8438,  0.1851,  ...,  0.9712,  0.4351, -0.0359],
        [-0.1255,  0.2333,  1.9580,  ..., -0.1156, -0.1487,  0.3220],
        [ 0.4050, -1.2705,  0.0177,  ...,  1.9941,  1.4160, -0.2791],
        ...,
        [ 0.5601, -0.0326, -0.2576,  ..., -0.2162,  2.1914, -1.0625],
        [ 0.4258, -1.2334, -0.2451,  ...,  0.0497,  1.0264, -0.6792],
        [ 0.2600, -0.3823,  0.3955,  ...,  1.0059,  0.8901, -0.2172]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([92, 47, 90, 12, 97, 16, 93, 49, 72, 32,  3, 94, 39, 82, 41, 76, 40, 72,
        59, 64, 49, 12, 90,  7,  8, 51, 52, 76, 75, 25, 73, 83, 44, 38, 82, 75,
        73, 10, 9

       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([82, 52, 23, 42, 49, 26, 58, 20, 90, 97, 18, 11, 46, 79, 35, 51, 77, 96,
        98,  9, 88, 10, 23, 49, 89, 25, 82, 43, 90, 73, 47, 69, 15, 56, 11, 18,
        84, 23, 83, 12, 13, 43, 47, 51, 63, 73, 92, 76, 62, 25],
       device='cuda:0')
tensor([[ 0.4490, -0.4690,  0.3560,  ...,  1.5654,  1.2822, -0.8501],
        [ 1.7041, -0.2556, -0.2090,  ...,  1.3574,  1.3857, -1.3799],
        [ 0.2196, -0.3623,  0.0080,  ...,  1.0234, -0.4248, -0.7441],
        ...,
        [ 0.4036,  2.2324,  1.1113,  ...,  0.0279,  0.2856,  0.1982],
        [ 0.4060,  0.1084, -0.2239,  ...,  0.3025, -0.5034, -0.7852],
        [ 0.2098, -0.5645, -0.1445,  ...,  3.1133, -0.5161, -0.6724]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 90,   0,  70,  92,  77,  33,  73,  65,  10,  81,  54,   5,  84,  84,
         82,   5,  52, 100,   4,  48,  11,  16,  93,  77,   9,  38,  61,  44,
         55,  34,  58

tensor([[ 0.1226, -1.1396, -0.1638,  ...,  1.4824,  0.1670,  0.0253],
        [-0.1032, -0.8589,  0.2778,  ...,  0.5098,  0.1919, -0.6724],
        [-0.4153, -0.5947, -0.1389,  ...,  1.4717,  0.5132,  1.1465],
        ...,
        [ 1.1357, -0.1853, -0.8384,  ...,  0.6772,  2.0508, -1.2949],
        [-0.3325, -0.7251,  1.0713,  ...,  0.5293, -1.0078,  0.8237],
        [-0.5586, -0.6187, -0.6719,  ...,  0.6714,  0.1816,  1.4180]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([90, 43, 81, 77, 53, 97, 16, 55, 77, 90, 86, 72, 87, 62, 98, 81,  9, 15,
        17, 65, 76, 74,  8, 59, 77, 75, 15, 97, 26, 74, 29, 55, 17, 97, 75, 57,
        41, 50, 84, 13, 66, 99, 14, 73, 43, 28, 67, 13, 54, 88],
       device='cuda:0')
tensor([[ 0.1779, -0.6870, -0.3782,  ...,  0.5684, -0.2197,  2.0547],
        [-0.7012, -1.1865, -0.0371,  ...,  0.0129,  0.2328,  1.9648],
        [-0.2386, -1.1748, -0.8262,  ...,  1.1113,  2.6289,  1.1357],
        ...,
        [ 0.9038, -0.139

tensor([[-0.4568, -0.6382,  0.4766,  ...,  0.2454, -0.0219,  0.0427],
        [ 0.4844, -1.2998,  0.0776,  ...,  1.0332,  1.3779, -0.0423],
        [-0.6768,  0.3462,  1.8975,  ...,  0.0455, -0.2639,  0.7798],
        ...,
        [-1.1035,  0.1583,  0.1689,  ...,  0.3323,  0.0453,  1.0898],
        [ 0.3772, -0.0603,  1.8691,  ...,  0.5322, -0.2102,  0.3835],
        [-0.4282,  0.4082, -0.6558,  ...,  1.1641,  0.0698,  0.4280]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([74, 84, 38,  3, 78, 98, 36, 43, 49, 63, 74, 32, 32, 40, 84, 52, 41, 47,
        64, 92, 74, 69, 63, 97, 45, 53, 11, 22, 66, 73,  4, 84, 49, 49, 40, 92,
        58, 34, 21, 23, 12, 52,  5, 72,  2, 32, 61, 60, 38, 96],
       device='cuda:0')
tensor([[ 0.7227, -0.0316,  0.0795,  ..., -0.0695,  1.9902, -1.4502],
        [ 0.6069, -0.7407,  0.5752,  ...,  1.4424,  0.6162, -0.7412],
        [ 0.2891, -0.9502,  0.1024,  ...,  0.6108,  1.6963,  0.1315],
        ...,
        [ 0.2612, -1.449

tensor([[ 1.1807,  0.4900, -0.3567,  ...,  0.6221,  2.5820, -0.9023],
        [-0.1481,  0.4656,  0.0789,  ..., -0.3020,  1.2119,  0.5410],
        [ 0.0816,  0.7983,  0.8447,  ..., -0.3582, -0.1165, -0.4797],
        ...,
        [-0.1516, -1.1777, -1.1367,  ...,  1.0029,  0.8862, -0.1586],
        [ 0.3889, -1.3086, -0.2546,  ...,  0.1001,  1.6445, -0.4717],
        [-0.1866,  0.4158,  1.5840,  ..., -0.0523, -1.0283,  0.3269]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([ 13,  79,  46,  90,  84,  74,  44,   3,  19,  56,  82,  72,  66,  32,
         61,  73,  70,  83,  48,  18,  77,  53,  89,  81,   3,  56,  11,  72,
         43,  18,  13,  95,  77,  18,  43,  48,  75,  85,  52,  36,  22, 100,
         96,  93, 100,  72,  55,  98,  50,  38], device='cuda:0')
tensor([[-0.2964, -1.2607, -0.5059,  ...,  0.4670,  0.2832,  1.5840],
        [ 0.6924, -0.6870, -0.1552,  ...,  1.2520,  1.5029, -1.3008],
        [-0.0390, -1.2949, -0.1078,  ...,  0.5015,  1.75

tensor([[ 0.1665,  2.0293,  0.6021,  ...,  0.8921, -0.3367, -0.6929],
        [-0.3010, -0.5815,  0.4307,  ...,  0.2515,  0.8740, -0.3599],
        [-0.3298,  1.0420,  0.6997,  ...,  0.0389,  0.6982,  0.7817],
        ...,
        [-0.4773, -1.0361,  0.0920,  ...,  0.4578, -0.1443,  1.1650],
        [ 1.0586, -0.3040,  0.5273,  ...,  0.4595,  1.1768, -1.2930],
        [-0.4451, -1.1934, -0.7192,  ...,  1.6748,  0.5527,  1.5625]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([34, 43, 29, 14, 23, 11, 48, 49, 30, 75, 82, 98, 82, 57, 90,  9, 80, 33,
        81, 66, 33, 71, 21, 49, 91,  3, 79, 26, 86, 23, 89, 71, 98, 42, 26, 84,
        49, 25, 52, 73, 54, 74, 43, 53, 97, 91, 99,  4, 50, 81],
       device='cuda:0')
tensor([[ 0.2131, -0.4004, -0.2375,  ...,  3.9297,  0.2050, -0.0133],
        [ 0.4314, -1.0547, -0.2432,  ..., -0.0235,  1.5352, -0.6528],
        [-0.6255,  0.0710, -0.3757,  ...,  0.5381,  0.9365,  2.1406],
        ...,
        [ 0.3245,  1.625

tensor([[ 1.5015e-01, -1.4492e+00, -4.3066e-01,  ...,  1.5957e+00,
          1.7354e+00,  8.0273e-01],
        [ 1.9629e-01, -1.0654e+00, -4.6509e-01,  ...,  1.3740e+00,
          5.4413e-02, -4.3506e-01],
        [-4.4507e-01, -1.2686e+00, -4.9686e-04,  ...,  1.7451e+00,
          5.1416e-01,  1.0078e+00],
        ...,
        [ 2.4719e-01,  3.7354e-02, -4.8486e-01,  ...,  6.2744e-01,
          4.2310e-01, -2.6343e-01],
        [-2.8296e-01,  1.2676e+00,  1.6621e+00,  ...,  2.1469e-02,
         -5.8057e-01,  2.8223e-01],
        [-8.1970e-02,  5.3955e-01,  9.5996e-01,  ...,  8.1152e-01,
          1.1846e+00,  5.8563e-02]], device='cuda:0', dtype=torch.float16,
       grad_fn=<AddmmBackward>) tensor([84, 56, 81, 22, 59, 80, 73, 73, 11, 81, 63, 59,  7, 91, 49, 18, 96, 62,
        44, 86, 65, 82, 58, 38, 97, 53, 74, 38, 50, 51, 42, 44, 22, 82, 80, 99,
        32, 49,  9, 27, 63, 50, 68, 49, 86, 93, 73, 89, 48, 44],
       device='cuda:0')
tensor([[-4.4327e-03, -5.3467e-01, -1.5356e-01,  

tensor([[ 0.6372,  2.4512,  0.6025,  ...,  0.1429,  1.5664,  0.1305],
        [-0.3457,  0.5288,  0.6436,  ...,  1.3818,  0.3330,  1.7061],
        [ 0.3320, -0.9683, -0.3728,  ...,  0.4167,  1.2803,  0.2661],
        ...,
        [-0.0304, -0.8149,  0.4570,  ...,  0.0048,  0.3008, -0.1832],
        [ 0.5928, -0.4138, -0.6904,  ...,  1.2021,  1.5361, -1.3730],
        [-0.2847, -1.2520, -0.6108,  ...,  1.3711,  0.7036,  0.7793]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([29, 77, 72, 47, 93, 47, 89, 84,  5, 18, 29,  1,  2, 26, 89, 84, 77,  6,
        49, 44, 44, 32, 14, 77, 49, 88, 46, 90, 59, 22, 56, 49, 24, 41, 51, 94,
        32, 61, 34, 77, 28, 88, 74, 80, 19, 76, 50, 43, 49, 36],
       device='cuda:0')
tensor([[ 0.6196, -0.0399, -0.0342,  ...,  0.1350,  4.0586, -0.5210],
        [ 1.4268, -1.1855,  0.0704,  ...,  1.4629,  1.4170, -1.1445],
        [ 0.6982, -0.5010, -0.7905,  ...,  0.7446,  0.5303, -0.8286],
        ...,
        [-0.1863, -0.913

tensor([[ 0.3589, -1.2539, -0.1708,  ...,  0.4653,  1.6533,  0.9424],
        [ 0.7817, -0.0513,  0.6646,  ...,  1.3447,  1.8066, -0.4270],
        [ 0.1144,  1.8838,  0.7749,  ..., -0.1760,  0.2524,  0.5923],
        ...,
        [ 0.6177, -0.7749, -0.4919,  ...,  0.8560,  2.6035,  0.2598],
        [ 0.2089, -0.3596, -1.0840,  ...,  1.1094,  1.3076, -0.3584],
        [ 0.7690, -1.4219, -0.1147,  ...,  2.1934,  1.5713, -0.5566]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) tensor([72, 49, 29, 87, 32, 48, 71, 31,  6, 81, 74, 56, 43, 48, 74, 48, 97, 51,
        82, 73, 32,  6, 55, 14, 81, 66, 97, 60, 90, 54, 76, 57, 12, 37, 22, 45,
        40, 59, 97, 19, 82, 54, 49, 42, 37, 67, 43, 72, 75, 90],
       device='cuda:0')
tensor([[ 0.6143,  1.8301,  0.7163,  ..., -0.4153,  0.5449, -0.7480],
        [-0.3611,  0.3745,  3.3770,  ..., -0.3892, -1.2461,  0.9619],
        [-0.2491, -0.4197,  0.5601,  ...,  0.4526,  0.3892, -1.0010],
        ...,
        [-1.0039, -1.222

KeyboardInterrupt: 

In [30]:
embeddings = torch.from_numpy(best_model.linear.weight.detach().cpu().numpy()).to(torch.float16)

In [31]:
zeroshot_weights = torch.from_numpy(np.array(embeddings).T).to(torch.float16)

def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [
        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
        for k in topk
    ]

# lazy load
if clip_model == None:
    clip_model, clip_preprocess = clip.load(clip_model_name, device)

with torch.no_grad():
    top1, top5, n = 0.0, 0.0, 0.0
    for i, (images, target) in enumerate(tqdm(test_loader)):
        images = images.cuda()
        target = target.cuda()

        # predict

        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features.to(device) @ zeroshot_weights.to(device)

        # measure accuracy
        acc1, _ = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        n += images.size(0)

top1 = (top1 / n) * 100

print("acc:", top1)

  0%|          | 0/303 [00:00<?, ?it/s]

acc: 82.69966996699671


## Deterministic - (using sci-kit)

In [32]:
global clip_model, clip_preprocess
from sklearn.linear_model import LogisticRegression

len_classes = len(classes)

train_features, train_labels = get_clip_features(train_loader)
test_features, test_labels = get_clip_features(test_loader)

train_features = train_features / train_features.norm(dim=-1, keepdim=True)
test_features = test_features / test_features.norm(dim=-1, keepdim=True)

classifier = LogisticRegression(C=1, max_iter=1000, n_jobs=4,verbose=1)
classifier.fit(train_features.cpu().numpy(), train_labels.cpu().numpy())
predictions = classifier.predict(test_features.cpu().numpy())
accuracy = np.mean((test_labels.cpu().numpy() == predictions).astype(np.float)) * 100.0

print(accuracy)

  0%|          | 0/1515 [00:00<?, ?it/s]

  0%|          | 0/303 [00:00<?, ?it/s]

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   1 out of   1 | elapsed:  2.1min finished


83.86138613861385
