
# Dependencies

In [1]:
import custom_models
#python packages
from PIL import Image
from tqdm.notebook import tqdm
#from tqdm import tqdm
import gc
import datetime
import os
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from skimage import io
#torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
#torchvision
import torchvision
from torchvision import datasets, models, transforms
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print("Using GPU!")
else:
    print("WARNING: Could not find GPU! Using CPU only")

PyTorch Version:  1.2.0
Torchvision Version:  0.4.0a0+6b959ee
Using GPU!


# Initialize CNN model

In [2]:
def make_CNN(model_name, num_classes, resume_from = None):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    # The model (nn.Module) to return
    model_ft = None
    # The input image is expected to be (input_size, input_size)
    input_size = 0
    
    # You may NOT use pretrained models!! 
    use_pretrained = False
    
    # By default, all parameters will be trained (useful when you're starting from scratch)
    # Within this function you can set .requires_grad = False for various parameters, if you
    # don't want to learn them

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
        
    elif model_name == "vgg13":
        """ VGG13_bn
        """
        model_ft = custom_models.vgg13_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
        
    elif model_name == "vgg16":
        """ VGG16_bn
        """
        model_ft = custom_models.vgg13_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
        
    elif model_name == "vgg19":
        """ VGG19_bn
        """
        model_ft = custom_models.vgg13_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes) 
        input_size = 224

    else:
        raise Exception("Invalid model name!")
    
    return model_ft, input_size

# Define Multimodal Model

In [3]:
class Chest_Disease_Net(nn.Module):
    """
    fc1: number of neurons in the hidden fully connected layer
    """
    def __init__(self, cnn_model_name, num_classes, num_multimodal_features=12, fc1_out=32, resume_from=None):
        #num_classes = 14
        #num_multimodal_features= 12
        super(Chest_Disease_Net, self).__init__()
        self.cnn, self.input_size = make_CNN(cnn_model_name, num_classes)#models.vgg11(pretrained=False, progress = True)
        #define output layers
        self.fc1 = nn.Linear(num_classes + num_multimodal_features, fc1_out) #takes in input of CNN and multimodal input
        self.fc2 = nn.Linear(fc1_out, num_classes)
        if resume_from is not None:
            print("Loading weights from %s" % resume_from)
            self.load_state_dict(torch.load(resume_from))
        
    def forward(self, image, data):
        x1 = self.cnn(image)
        #print("x1", x1.shape)
        x2 = data
        #print("x2", x2.shape)
        #print("x1: ", x1, type(x1))
        #print("x2: ", x2, type(x2))
        #x = torch.cat((x1, x2), dim=1)  
        x = torch.cat((x1.float(), x2.float()), dim=1) ### ???
        #print("concat", x.shape)
        x = F.relu(self.fc1(x))
        #print("relu", x.shape)
        x = self.fc2(x)
        print('forward output: ', x)
       # print("fc2", x.shape)
        return x.double() ### ???

# Data Loading

In [4]:
class MultimodalDataset(Dataset):
    """
    Custom dataset definition
    """
    def __init__(self, csv_path, img_path, transform=None):
        """
        """
        self.df = pd.read_csv(csv_path)
        self.img_path = img_path
        self.transform = transform
        self.diseases = self.get_diseases()
            
    def __getitem__(self, index):
        """
        """
        img_name = self.df.iloc[index]["img_name"] 
        img_path = os.path.join(self.img_path, img_name)
        image = Image.open(img_path)
        image = image.convert("RGB")
        image = np.asarray(image)
        if self.transform is not None:
            image = self.transform(image)
        dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor # ???
        features = np.fromstring(self.df.iloc[index]["feature"][1:-1], sep=",") # ???
        features = torch.from_numpy(features.astype("float")) # ???
        #label = int(self.df.iloc[index]['label'])
        labels = torch.tensor(list(self.df.iloc[index][self.diseases]), dtype = torch.float64)
        #print("Label type: ", type(label))
        #label = np.int_(label) #???
        #print("label type post casting: ", type(label))
        return image, features, labels
        
    def __len__(self):
        return len(self.df)
    
    def get_diseases(self):
        cols = list(self.df.columns)
        cols.remove('disease')
        cols.remove('feature')
        cols.remove('img_name')
        cols.remove('dataset_type')
        return cols
        

In [5]:
def get_dataloaders(input_size, batch_size, num_classes, augment=False, shuffle = True):
    # How to transform the image when you are loading them.
    # you'll likely want to mess with the transforms on the training set.
    
    # For now, we resize/crop the image to the correct input size for our network,
    # then convert it to a [C,H,W] tensor, then normalize it to values with a given mean/stdev. These normalization constants
    # are derived from aggregating lots of data and happen to produce better results.
    data_transforms = {
        'train': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            #Add extra transformations for data augmentation
            transforms.RandomApply([
                transforms.RandomChoice([
                    transforms.RandomAffine(degrees=20),
                    transforms.RandomAffine(degrees=0,scale=(0.1, 0.15)),
                    transforms.RandomAffine(degrees=0,translate=(0.2,0.2)),
                    #transforms.RandomAffine(degrees=0,shear=0.15),
                    transforms.RandomHorizontalFlip(p=1.0)
                ] if augment else [transforms.RandomAffine(degrees=0)])#else do nothing
            ], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.225])
        ]),
        'test': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.225])
        ])
    }
    # Create training and validation datasets
    data_subsets = {x: MultimodalDataset(csv_path="./data/"+x+"_dataset{}.csv".format(num_classes), 
                                         img_path="/storage/images", 
                                         transform=data_transforms[x]) for x in data_transforms.keys()}
    # Create training and validation dataloaders
    # Never shuffle the test set
    dataloaders_dict = {x: DataLoader(data_subsets[x], batch_size=batch_size, shuffle=False if x != 'train' else shuffle, num_workers=4) for x in data_transforms.keys()}
    return dataloaders_dict

# Training

In [6]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, model_name=str(datetime.datetime.now()), 
                save_dir = None, save_all_epochs=False, num_epochs=25):
    '''
    model: The NN to train
    dataloaders: A dictionary containing at least the keys 
                 'train','val' that maps to Pytorch data loaders for the dataset
    criterion: The Loss function
    optimizer: The algorithm to update weights 
               (Variations on gradient descent)
    num_epochs: How many epochs to train for
    save_dir: Where to save the best model weights that are found, 
              as they are found. Will save to save_dir/weights_best.pt
              Using None will not write anything to disk
    save_all_epochs: Whether to save weights for ALL epochs, not just the best
                     validation error epoch. Will save to save_dir/weights_e{#}.pt
    '''
    since = time.time()

    val_acc_history = []
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            # TQDM has nice progress bars
            for inputs, features, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                features = features.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs, features)
                    print("model outputs: ", outputs, outputs.size())
                    #print("model labels: ", labels.size())
                    loss = criterion(outputs, labels)

                    # torch.max outputs the maximum value, and its index
                    # Since the input is batched, we take the max along axis 1
                    # (the meaningful outputs)
                    #print("outputs: ", outputs)
                    #preds = torch.max(outputs, 1)
                    preds = (outputs > 0).type(torch.float64)
                    #_, preds = torch.max(outputs, 1)
                    #print("new preds: ", preds)
                    # backprop + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                #print("loss: ", loss.item())
                #print("inputs: ", inputs.size(0), inputs.size())
                
                running_loss += loss.item() * inputs.size(0)
                #print("running loss: ", running_loss)
                print("predictions: ", preds, preds.size())
                print("Labels: ", labels.data, labels.data.size())
                running_corrects += torch.sum(preds == labels.data)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            #is the accuracy calculated correctly?
            print("running_corrects: ", running_corrects.double(), running_corrects.size())
            print("dataloaders len: ", len(dataloaders[phase].dataset)) 
            #maybe dataloaders length * number of classes? Model must predict for all classes
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), save_dir + "/{}_best_weights_1.pt".format(model_name))
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        print()
        scheduler.step()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

# Optimizer & Loss

In [7]:
def make_optimizer(model):
    # Get all the parameters
    params_to_update = model.parameters()
    print("Params to learn:")
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

    # Use SGD
    optimizer = optim.SGD(params_to_update, lr=0.01, momentum=0.9)
    return optimizer

def get_loss(num_classes,device):
    # Create an instance of the loss function
    #set weights to account for unbalanced data.
    #In expectation every category class contributes the same to the loss
    pos_weight = np.array([    0.025944895136267, 0.059170436277992,0.049543908183484,
                            0.124453250588807,0.013364985606939,0.185904145949381,
                            0.108022729821676,0.564067815619275,0.054011364910838])
    pos_weight = torch.tensor(pos_weight,dtype=torch.float64) if num_classes == 9 \
    else torch.tensor([])
    pos_weight = pos_weight.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    return criterion

# Define Experiment Parameters

In [8]:
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet]
# You can add your own, or modify these however you wish!
model_name = "alexnet"

# Number of classes in the dataset
# Miniplaces has 100
num_classes = 9# set to 9 or 14

# Batch size for training (change depending on how much memory you have)
# You should use a power of 2.
batch_size = 64

# Shuffle the input data?
shuffle_datasets = True

# Number of epochs to train for 
num_epochs = 10

### IO
# Path to a model file to use to start weights at
#resume_from = "/home/ubuntu/6.867-xray-project/weights/data_aug_vgg.pt"
resume_from = None

# Directory to save weights to
save_dir = "weights"
os.makedirs(save_dir, exist_ok=True)

# If True saves the weights for all epochs, else only saves the weight of best one
save_all_epochs = False



# Train!

In [9]:
gc.collect()

7

In [10]:
# Initialize the model for this run

model = Chest_Disease_Net(cnn_model_name = model_name, num_classes = num_classes, resume_from = resume_from)
input_size = model.input_size
dataloaders = get_dataloaders(input_size, batch_size, num_classes, shuffle_datasets)
criterion = get_loss(num_classes=num_classes,device=device)

# Move the model to the gpu if needed
model = model.to(device)

optimizer = make_optimizer(model)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10],gamma=0.1)
# Train the model!
trained_model, validation_history = train_model(model=model, dataloaders=dataloaders, criterion=criterion, optimizer=optimizer,
            scheduler=scheduler, model_name=model_name, save_dir=save_dir, save_all_epochs=save_all_epochs, num_epochs=num_epochs)

Params to learn:
	 cnn.features.0.weight
	 cnn.features.0.bias
	 cnn.features.3.weight
	 cnn.features.3.bias
	 cnn.features.6.weight
	 cnn.features.6.bias
	 cnn.features.8.weight
	 cnn.features.8.bias
	 cnn.features.10.weight
	 cnn.features.10.bias
	 cnn.classifier.1.weight
	 cnn.classifier.1.bias
	 cnn.classifier.4.weight
	 cnn.classifier.4.bias
	 cnn.classifier.6.weight
	 cnn.classifier.6.bias
	 fc1.weight
	 fc1.bias
	 fc2.weight
	 fc2.bias
Epoch 0/9
----------


HBox(children=(IntProgress(value=0, max=1005), HTML(value='')))

forward output:  tensor([[ 0.1709,  0.0773,  0.1751,  0.0766,  0.2944,  0.1314,  0.1272, -0.0021,
          0.0619],
        [ 0.0897,  0.0894,  0.1869,  0.2028,  0.2433,  0.1012,  0.1654, -0.0467,
          0.1009],
        [ 0.1716,  0.0778,  0.1744,  0.0761,  0.2951,  0.1315,  0.1271, -0.0015,
          0.0619],
        [ 0.1253,  0.0733,  0.1479,  0.1797,  0.2802,  0.1728,  0.1598, -0.0167,
          0.0840],
        [ 0.1365,  0.0036,  0.1522,  0.0838,  0.3065,  0.1632,  0.1544,  0.0536,
          0.0463],
        [ 0.1711,  0.0752,  0.1745,  0.0755,  0.2940,  0.1308,  0.1262, -0.0034,
          0.0615],
        [ 0.1604,  0.0640,  0.1267,  0.1059,  0.2131,  0.1181,  0.1653, -0.0437,
          0.0701],
        [ 0.1573,  0.0368,  0.1162,  0.0934,  0.2960,  0.1355,  0.1369,  0.0532,
          0.1108],
        [ 0.1598,  0.1212,  0.1715,  0.1460,  0.3004,  0.1461,  0.1944,  0.0333,
          0.0541],
        [ 0.1709,  0.0769,  0.1752,  0.0763,  0.2945,  0.1319,  0.1277, -0.0027,
  

model outputs:  tensor([[ 0.1246,  0.0742,  0.1485,  0.1795,  0.2793,  0.1709,  0.1584, -0.0169,
          0.0817],
        [ 0.1592,  0.1201,  0.1707,  0.1451,  0.2991,  0.1457,  0.1933,  0.0325,
          0.0529],
        [ 0.1237,  0.0731,  0.1475,  0.1780,  0.2786,  0.1720,  0.1598, -0.0190,
          0.0825],
        [ 0.1596,  0.0625,  0.1265,  0.1039,  0.2130,  0.1176,  0.1655, -0.0440,
          0.0684],
        [ 0.2338,  0.0297,  0.1561,  0.0213,  0.2843,  0.1058,  0.1321, -0.0431,
          0.0265],
        [ 0.1607,  0.0621,  0.1268,  0.1038,  0.2129,  0.1180,  0.1648, -0.0447,
          0.0678],
        [ 0.0783,  0.0944,  0.1389,  0.0702,  0.1867,  0.1271,  0.1831, -0.0495,
          0.1104],
        [ 0.1568,  0.0378,  0.1143,  0.0939,  0.2941,  0.1338,  0.1350,  0.0541,
          0.1112],
        [ 0.1489,  0.1144,  0.1580,  0.1765,  0.2958,  0.1409,  0.1908,  0.0508,
          0.1191],
        [ 0.1596,  0.0630,  0.1259,  0.1047,  0.2130,  0.1181,  0.1647, -0.0442,
   

predictions:  tensor([[1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1.,

forward output:  tensor([[ 0.1423,  0.1056,  0.1514,  0.1687,  0.2875,  0.1363,  0.1858,  0.0483,
          0.1125],
        [ 0.1647,  0.0665,  0.1671,  0.0667,  0.2845,  0.1254,  0.1190, -0.0067,
          0.0531],
        [ 0.1681,  0.0207,  0.1306,  0.0867,  0.2296,  0.1492,  0.1678, -0.0716,
          0.0536],
        [ 0.1171,  0.0648,  0.1404,  0.1714,  0.2700,  0.1660,  0.1523, -0.0183,
          0.0766],
        [ 0.1680,  0.0211,  0.1297,  0.0867,  0.2297,  0.1488,  0.1668, -0.0719,
          0.0538],
        [ 0.1170,  0.0650,  0.1405,  0.1715,  0.2698,  0.1664,  0.1524, -0.0189,
          0.0763],
        [ 0.1506,  0.0280,  0.1075,  0.0849,  0.2861,  0.1283,  0.1293,  0.0519,
          0.1041],
        [ 0.1538,  0.0552,  0.1186,  0.0985,  0.2054,  0.1130,  0.1577, -0.0445,
          0.0644],
        [ 0.1500,  0.0252,  0.1077,  0.0834,  0.2857,  0.1284,  0.1302,  0.0510,
          0.1036],
        [ 0.1171,  0.0653,  0.1396,  0.1724,  0.2697,  0.1672,  0.1514, -0.0186,
  

forward output:  tensor([[ 0.1272,  0.1340,  0.1324,  0.1744,  0.2493,  0.1314,  0.2182, -0.0526,
          0.0751],
        [ 0.1500,  0.1083,  0.1597,  0.1340,  0.2862,  0.1377,  0.1823,  0.0302,
          0.0441],
        [ 0.1642,  0.0170,  0.1262,  0.0830,  0.2254,  0.1464,  0.1636, -0.0722,
          0.0505],
        [ 0.1463,  0.0226,  0.1034,  0.0806,  0.2809,  0.1256,  0.1258,  0.0502,
          0.1013],
        [ 0.1460,  0.0179,  0.1022,  0.0781,  0.2801,  0.1241,  0.1262,  0.0491,
          0.1021],
        [ 0.1135,  0.0614,  0.1363,  0.1677,  0.2654,  0.1636,  0.1487, -0.0197,
          0.0733],
        [ 0.1398,  0.1016,  0.1476,  0.1644,  0.2828,  0.1343,  0.1800,  0.0470,
          0.1086],
        [ 0.1661,  0.0167,  0.1263,  0.0824,  0.2259,  0.1467,  0.1632, -0.0718,
          0.0501],
        [ 0.1388,  0.1026,  0.1473,  0.1654,  0.2827,  0.1337,  0.1801,  0.0475,
          0.1096],
        [ 0.1501,  0.0524,  0.1163,  0.0942,  0.2005,  0.1116,  0.1552, -0.0473,
  

model outputs:  tensor([[ 0.1432,  0.0188,  0.0989,  0.0765,  0.2763,  0.1222,  0.1209,  0.0493,
          0.0975],
        [ 0.1092,  0.0573,  0.1319,  0.1631,  0.2600,  0.1609,  0.1448, -0.0216,
          0.0696],
        [ 0.1098,  0.0570,  0.1318,  0.1625,  0.2598,  0.1613,  0.1454, -0.0223,
          0.0691],
        [ 0.1235,  0.1292,  0.1278,  0.1709,  0.2437,  0.1281,  0.2124, -0.0545,
          0.0713],
        [ 0.1100,  0.0546,  0.1314,  0.1622,  0.2599,  0.1613,  0.1459, -0.0209,
          0.0704],
        [ 0.1094,  0.0568,  0.1322,  0.1632,  0.2604,  0.1607,  0.1445, -0.0209,
          0.0695],
        [ 0.0743,  0.0733,  0.1702,  0.1873,  0.2225,  0.0903,  0.1493, -0.0512,
          0.0874],
        [ 0.1530,  0.0012,  0.0933,  0.0168,  0.2092,  0.1026,  0.0793, -0.0419,
          0.0497],
        [ 0.0486,  0.0079,  0.1622,  0.0522,  0.1930,  0.1132,  0.2219, -0.0832,
          0.0390],
        [ 0.1465,  0.0488,  0.1122,  0.0916,  0.1961,  0.1091,  0.1514, -0.0480,
   

predictions:  tensor([[1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1.,

model outputs:  tensor([[ 1.3566e-01,  3.2216e-02,  1.0013e-01,  7.5070e-02,  1.8062e-01,
          9.9154e-02,  1.3874e-01, -5.0930e-02,  4.3535e-02],
        [ 9.5358e-02,  3.8822e-02,  1.1621e-01,  1.4667e-01,  2.4147e-01,
          1.4971e-01,  1.3105e-01, -2.3935e-02,  5.6778e-02],
        [ 9.6053e-02,  3.8761e-02,  1.1706e-01,  1.4529e-01,  2.4212e-01,
          1.4905e-01,  1.3107e-01, -2.4748e-02,  5.5532e-02],
        [ 1.2834e-01, -6.1770e-04,  8.4094e-02,  5.8741e-02,  2.5824e-01,
          1.1089e-01,  1.0747e-01,  4.4471e-02,  8.3731e-02],
        [ 1.3476e-01,  3.3638e-02,  9.9665e-02,  7.6716e-02,  1.8017e-01,
          9.9131e-02,  1.3797e-01, -5.1275e-02,  4.4668e-02],
        [ 1.3420e-01,  3.3312e-02,  9.8955e-02,  7.6460e-02,  1.7977e-01,
          9.9658e-02,  1.3735e-01, -5.2148e-02,  4.4707e-02],
        [ 1.4118e-01, -1.4958e-02,  7.9507e-02,  2.1011e-03,  1.9315e-01,
          9.3437e-02,  6.6237e-02, -4.5165e-02,  3.7390e-02],
        [ 1.4102e-01, -1.4442e-0

predictions:  tensor([[1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1.,

forward output:  tensor([[ 0.1273,  0.0178,  0.1247,  0.0230,  0.2343,  0.0929,  0.0784, -0.0197,
          0.0180],
        [ 0.1053,  0.0617,  0.1103,  0.1275,  0.2394,  0.1077,  0.1428,  0.0385,
          0.0761],
        [ 0.1052,  0.0624,  0.1106,  0.1282,  0.2394,  0.1080,  0.1425,  0.0383,
          0.0758],
        [ 0.1204,  0.0155,  0.0841,  0.0593,  0.1607,  0.0886,  0.1218, -0.0561,
          0.0292],
        [ 0.1262, -0.0333,  0.0624, -0.0151,  0.1725,  0.0802,  0.0490, -0.0493,
          0.0246],
        [ 0.1268,  0.0186,  0.1256,  0.0234,  0.2343,  0.0934,  0.0797, -0.0196,
          0.0179],
        [ 0.1109, -0.0211,  0.0655,  0.0403,  0.2352,  0.0969,  0.0885,  0.0387,
          0.0684],
        [ 0.0783,  0.0187,  0.0980,  0.1268,  0.2198,  0.1363,  0.1151, -0.0294,
          0.0407],
        [ 0.1051,  0.0615,  0.1104,  0.1277,  0.2391,  0.1075,  0.1422,  0.0378,
          0.0759],
        [ 0.1052,  0.0621,  0.1105,  0.1281,  0.2392,  0.1075,  0.1428,  0.0387,
  

model outputs:  tensor([[ 0.1148,  0.0093,  0.0782,  0.0542,  0.1537,  0.0833,  0.1161, -0.0561,
          0.0252],
        [ 0.0372,  0.0296,  0.1291,  0.1447,  0.1754,  0.0588,  0.1110, -0.0599,
          0.0527],
        [ 0.1214,  0.0122,  0.1171,  0.0167,  0.2268,  0.0888,  0.0709, -0.0208,
          0.0132],
        [ 0.1881, -0.0261,  0.1054, -0.0322,  0.2249,  0.0708,  0.0871, -0.0537,
         -0.0173],
        [ 0.1207, -0.0407,  0.0569, -0.0211,  0.1656,  0.0768,  0.0444, -0.0511,
          0.0187],
        [ 0.0729,  0.0118,  0.0909,  0.1206,  0.2115,  0.1320,  0.1091, -0.0311,
          0.0355],
        [ 0.1274, -0.0347,  0.0827,  0.0359,  0.1728,  0.1131,  0.1212, -0.0821,
          0.0150],
        [ 0.0992,  0.0547,  0.1049,  0.1211,  0.2314,  0.1036,  0.1359,  0.0354,
          0.0696],
        [ 0.0363,  0.0291,  0.1294,  0.1440,  0.1748,  0.0579,  0.1109, -0.0599,
          0.0533],
        [ 0.1146,  0.0088,  0.0787,  0.0532,  0.1540,  0.0839,  0.1166, -0.0571,
   

predictions:  tensor([[1., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1.,

forward output:  tensor([[ 0.0981, -0.0110,  0.0606,  0.0354,  0.1315,  0.0704,  0.0975, -0.0607,
          0.0090],
        [ 0.0537, -0.0111,  0.0700,  0.0987,  0.1877,  0.1152,  0.0912, -0.0336,
          0.0187],
        [ 0.0808,  0.0329,  0.0844,  0.1009,  0.2082,  0.0883,  0.1163,  0.0322,
          0.0526],
        [ 0.1029, -0.0129,  0.0976, -0.0052,  0.2019,  0.0714,  0.0520, -0.0268,
         -0.0047],
        [ 0.1073, -0.0555,  0.0619,  0.0156,  0.1493,  0.0963,  0.1017, -0.0858,
         -0.0015],
        [ 0.0181,  0.0076,  0.1082,  0.1230,  0.1505,  0.0422,  0.0916, -0.0645,
          0.0351],
        [ 0.1045, -0.0626,  0.0380, -0.0418,  0.1419,  0.0612,  0.0251, -0.0561,
          0.0038],
        [ 0.1085, -0.0563,  0.0624,  0.0153,  0.1488,  0.0966,  0.1009, -0.0864,
         -0.0020],
        [ 0.0912,  0.0400,  0.0966,  0.0679,  0.2097,  0.0882,  0.1190,  0.0150,
         -0.0083],
        [ 0.0904,  0.0399,  0.0974,  0.0685,  0.2103,  0.0904,  0.1197,  0.0136,
  

model outputs:  tensor([[ 0.0920, -0.0179,  0.0541,  0.0287,  0.1248,  0.0658,  0.0920, -0.0620,
          0.0029],
        [ 0.0914, -0.0170,  0.0546,  0.0292,  0.1240,  0.0662,  0.0922, -0.0631,
          0.0025],
        [ 0.1009, -0.0625,  0.0543,  0.0093,  0.1414,  0.0903,  0.0941, -0.0860,
         -0.0073],
        [ 0.0734, -0.0742,  0.0843, -0.0488,  0.1408,  0.1079,  0.0966, -0.0552,
         -0.0380],
        [ 0.0968, -0.0199,  0.0899, -0.0115,  0.1934,  0.0663,  0.0444, -0.0285,
         -0.0098],
        [ 0.0742,  0.0262,  0.0771,  0.0950,  0.2004,  0.0832,  0.1095,  0.0312,
          0.0467],
        [ 0.0747,  0.0262,  0.0773,  0.0945,  0.2002,  0.0840,  0.1086,  0.0306,
          0.0460],
        [ 0.0747,  0.0256,  0.0778,  0.0940,  0.2004,  0.0837,  0.1090,  0.0307,
          0.0453],
        [ 0.0745,  0.0258,  0.0770,  0.0946,  0.2003,  0.0827,  0.1094,  0.0317,
          0.0473],
        [ 0.0474, -0.0188,  0.0627,  0.0913,  0.1788,  0.1107,  0.0850, -0.0361,
   

predictions:  tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1.,

forward output:  tensor([[-2.7999e-02, -8.1055e-02,  8.0428e-02, -3.3421e-02,  9.6028e-02,
          4.9144e-02,  1.3952e-01, -1.0291e-01, -3.2272e-02],
        [ 2.8131e-02, -4.2257e-02,  4.1748e-02,  6.7572e-02,  1.5417e-01,
          9.2714e-02,  6.6021e-02, -4.0235e-02, -7.1238e-03],
        [ 2.7694e-02, -4.2688e-02,  4.1880e-02,  6.7921e-02,  1.5422e-01,
          9.2325e-02,  6.6363e-02, -3.9360e-02, -6.6642e-03],
        [ 1.4235e-01, -7.1877e-02,  6.2026e-02, -7.5789e-02,  1.7554e-01,
          3.5976e-02,  4.8857e-02, -5.7718e-02, -5.6171e-02],
        [ 5.9581e-02, -8.5904e-02,  8.1391e-03, -1.9957e-02,  1.6838e-01,
          5.2031e-02,  3.3378e-02,  2.3764e-02,  2.0858e-02],
        [ 5.5167e-02,  2.5609e-03,  5.6615e-02,  7.2953e-02,  1.7556e-01,
          6.7633e-02,  8.8027e-02,  2.6001e-02,  2.7390e-02],
        [ 2.7625e-02, -4.3335e-02,  4.1384e-02,  6.7701e-02,  1.5360e-01,
          9.3164e-02,  6.6924e-02, -3.9955e-02, -6.0409e-03],
        [ 8.3113e-02, -9.0630e-

predictions:  tensor([[1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 1., 1., 0., 1.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 1., 1.],
        [1., 0., 1.,

forward output:  tensor([[ 0.0589, -0.0698,  0.0481, -0.0579,  0.1413,  0.0324,  0.0040, -0.0407,
         -0.0480],
        [ 0.0590, -0.0701,  0.0473, -0.0583,  0.1417,  0.0317,  0.0032, -0.0394,
         -0.0476],
        [ 0.0631, -0.1141,  0.0114, -0.0384,  0.0907,  0.0565,  0.0549, -0.0963,
         -0.0430],
        [ 0.0086, -0.0660,  0.0197,  0.0438,  0.1280,  0.0754,  0.0472, -0.0450,
         -0.0264],
        [ 0.0170,  0.0107,  0.0176,  0.0581,  0.1127,  0.0434,  0.1001, -0.0797,
         -0.0269],
        [-0.0268, -0.0271,  0.0272, -0.0494,  0.0543,  0.0394,  0.0714, -0.0761,
          0.0104],
        [ 0.0178,  0.0097,  0.0176,  0.0570,  0.1130,  0.0430,  0.1004, -0.0794,
         -0.0270],
        [ 0.0587, -0.0705,  0.0483, -0.0581,  0.1409,  0.0318,  0.0037, -0.0409,
         -0.0479],
        [ 0.0597, -0.0711,  0.0470, -0.0577,  0.1418,  0.0314,  0.0035, -0.0393,
         -0.0471],
        [ 0.0567, -0.0607,  0.0163, -0.0118,  0.0787,  0.0382,  0.0539, -0.0718,
  

forward output:  tensor([[ 5.9816e-02, -1.2056e-01, -1.3906e-02, -9.5342e-02,  8.2968e-02,
          2.1836e-02, -2.5097e-02, -6.6683e-02, -4.4216e-02],
        [ 5.1361e-02, -6.7734e-02,  8.7776e-03, -1.7867e-02,  7.1503e-02,
          3.3053e-02,  4.7116e-02, -7.1586e-02, -3.7466e-02],
        [ 5.2916e-02, -7.9233e-02,  3.9914e-02, -6.5398e-02,  1.3290e-01,
          2.5730e-02, -3.4535e-03, -4.1455e-02, -5.3394e-02],
        [ 5.9382e-02, -1.2098e-01, -1.3511e-02, -9.5185e-02,  8.1943e-02,
          2.1255e-02, -2.4288e-02, -6.8218e-02, -4.4267e-02],
        [ 9.9789e-03,  2.6564e-03,  1.0121e-02,  5.0366e-02,  1.0453e-01,
          3.6957e-02,  9.2848e-02, -8.0799e-02, -3.3306e-02],
        [ 5.9503e-02, -1.1990e-01, -1.4180e-02, -9.4496e-02,  8.2511e-02,
          2.2078e-02, -2.5467e-02, -6.6702e-02, -4.3648e-02],
        [ 5.2568e-02, -7.8133e-02,  3.9232e-02, -6.5216e-02,  1.3283e-01,
          2.6187e-02, -5.2004e-03, -4.1688e-02, -5.3167e-02],
        [ 5.9977e-02, -1.2029e-

model outputs:  tensor([[-5.5308e-03, -8.2850e-02,  4.9168e-03,  2.8051e-02,  1.1038e-01,
          6.5040e-02,  3.4308e-02, -4.9435e-02, -3.9492e-02],
        [ 1.0783e-01, -1.0688e-01,  2.9184e-02, -1.0922e-01,  1.3706e-01,
          1.0442e-02,  1.9682e-02, -6.1671e-02, -8.6312e-02],
        [ 2.1394e-02, -3.6530e-02,  2.1413e-02,  3.6449e-02,  1.3404e-01,
          4.1415e-02,  5.4118e-02,  1.7801e-02, -5.0735e-03],
        [ 3.2889e-02, -2.9822e-02,  3.2742e-02,  7.1225e-04,  1.3450e-01,
          3.8642e-02,  5.6009e-02,  1.1692e-03, -6.2395e-02],
        [ 4.6263e-02, -8.7493e-02,  3.3353e-02, -7.4257e-02,  1.2408e-01,
          2.0218e-02, -1.1057e-02, -4.4523e-02, -6.1246e-02],
        [ 5.1925e-02, -1.3061e-01, -2.3110e-03, -5.4371e-02,  7.2900e-02,
          4.4273e-02,  4.0139e-02, -9.8969e-02, -5.5117e-02],
        [ 2.1126e-02, -3.5928e-02,  2.0977e-02,  3.6849e-02,  1.3391e-01,
          4.1716e-02,  5.3322e-02,  1.7407e-02, -5.1383e-03],
        [ 1.0995e-01, -1.0795e-0

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 

forward output:  tensor([[ 2.7654e-02, -1.1323e-01,  1.3064e-02, -9.6856e-02,  9.8342e-02,
          1.1333e-03, -2.9720e-02, -4.8953e-02, -8.0249e-02],
        [ 3.5147e-02, -1.4888e-01, -3.9313e-02, -1.2192e-01,  5.3433e-02,
         -1.1095e-03, -4.8925e-02, -6.9943e-02, -6.8531e-02],
        [-6.4876e-02, -8.5073e-02,  1.9506e-02,  3.1864e-02,  4.6684e-02,
         -2.8971e-02,  1.0277e-02, -8.2475e-02, -4.0626e-02],
        [ 2.5911e-02, -9.5684e-02, -1.6572e-02, -4.3782e-02,  4.1783e-02,
          1.3681e-02,  2.2493e-02, -7.7353e-02, -6.1466e-02],
        [-5.7823e-02, -6.2962e-02, -5.3753e-03, -8.4563e-02,  1.6105e-02,
          1.3518e-02,  3.9390e-02, -8.3632e-02, -1.9854e-02],
        [-1.7927e-02, -2.9483e-02, -1.8776e-02,  1.9891e-02,  7.0941e-02,
          1.3843e-02,  6.4882e-02, -8.6596e-02, -6.0067e-02],
        [ 6.4518e-02, -9.9941e-02,  8.9041e-03, -1.1435e-01,  3.6277e-02,
         -4.3047e-02,  2.5378e-02, -1.2422e-01, -5.1425e-02],
        [ 2.7501e-02, -1.1397e-

predictions:  tensor([[1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 1., 1., 1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 1., 1., 1., 1., 1., 0.],
        [1., 0., 1., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 1., 0., 1., 0., 0.],
        [1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 1., 1., 1., 1., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 1., 1., 1., 1., 1., 0.],
        [1., 0., 1.,

predictions:  tensor([[1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1., 1., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1., 1., 0., 0.],
        [1., 0., 0.,

forward output:  tensor([[-0.0534, -0.1406, -0.0476, -0.0262,  0.0512,  0.0235, -0.0090, -0.0576,
         -0.0832],
        [-0.0466, -0.0615, -0.0476, -0.0105,  0.0375, -0.0106,  0.0365, -0.0928,
         -0.0869],
        [-0.0476, -0.0602, -0.0476, -0.0092,  0.0380, -0.0095,  0.0372, -0.0932,
         -0.0874],
        [-0.0472, -0.0610, -0.0477, -0.0101,  0.0372, -0.0110,  0.0357, -0.0937,
         -0.0870],
        [ 0.0022, -0.1252, -0.0411, -0.0719,  0.0118, -0.0064, -0.0015, -0.0825,
         -0.0865],
        [ 0.0004, -0.1456, -0.0181, -0.1280,  0.0644, -0.0219, -0.0593, -0.0558,
         -0.1049],
        [ 0.0089, -0.1780, -0.0655, -0.1482,  0.0243, -0.0241, -0.0721, -0.0737,
         -0.0938],
        [ 0.0088, -0.1774, -0.0663, -0.1480,  0.0250, -0.0232, -0.0734, -0.0723,
         -0.0937],
        [ 0.0004, -0.1461, -0.0174, -0.1266,  0.0639, -0.0224, -0.0587, -0.0555,
         -0.1044],
        [ 0.0003, -0.1460, -0.0185, -0.1269,  0.0646, -0.0220, -0.0585, -0.0550,
  

predictions:  tensor([[1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0.,

forward output:  tensor([[-6.1548e-02, -2.3246e-01, -6.2176e-02, -1.3953e-01,  5.4118e-02,
         -7.2991e-03, -4.8695e-02,  1.0300e-02, -1.4801e-01],
        [-2.3391e-02, -2.2202e-01, -8.3711e-02, -1.3878e-01, -1.6807e-02,
         -2.2430e-02, -3.2735e-02, -1.1160e-01, -1.2439e-01],
        [-1.2174e-01, -1.4703e-01, -3.9250e-02, -2.8553e-02, -2.0992e-02,
         -7.8501e-02, -4.3167e-02, -9.4132e-02, -9.2199e-02],
        [-2.5098e-02, -1.8086e-01, -4.6709e-02, -1.5857e-01,  3.1272e-02,
         -4.7478e-02, -8.6045e-02, -6.0993e-02, -1.3068e-01],
        [-7.6850e-02, -9.2770e-02, -7.6184e-02, -4.0533e-02,  3.9780e-03,
         -3.6568e-02,  8.5507e-03, -9.8660e-02, -1.1351e-01],
        [-1.7065e-02, -2.0718e-01, -9.2085e-02, -1.7525e-01, -4.7161e-03,
         -4.8348e-02, -9.6635e-02, -7.5866e-02, -1.1924e-01],
        [-3.6995e-02, -2.0616e-01, -3.1909e-02, -1.7434e-01,  3.8943e-03,
          1.8871e-02, -1.8593e-02, -8.2663e-02, -1.5789e-01],
        [-3.6603e-02, -2.0550e-

model outputs:  tensor([[ 0.0256, -0.1950, -0.0493, -0.1908,  0.0482, -0.0546, -0.0509, -0.0686,
         -0.1630],
        [-0.0615, -0.1307, -0.0609, -0.0498,  0.0360, -0.0231, -0.0265, -0.0032,
         -0.0853],
        [-0.0285, -0.1607, -0.0740, -0.1036, -0.0234, -0.0317, -0.0326, -0.0863,
         -0.1160],
        [-0.0880, -0.1823, -0.0845, -0.0640,  0.0102, -0.0061, -0.0398, -0.0624,
         -0.1155],
        [-0.0234, -0.2146, -0.0990, -0.1820, -0.0115, -0.0538, -0.1032, -0.0753,
         -0.1259],
        [-0.0489, -0.1243, -0.0513, -0.0888,  0.0344, -0.0324, -0.0250, -0.0181,
         -0.1366],
        [-0.0321, -0.1877, -0.0537, -0.1645,  0.0230, -0.0522, -0.0933, -0.0623,
         -0.1370],
        [-0.1273, -0.1551, -0.0462, -0.0362, -0.0290, -0.0847, -0.0488, -0.0955,
         -0.0995],
        [-0.0281, -0.1609, -0.0724, -0.1041, -0.0244, -0.0317, -0.0322, -0.0877,
         -0.1168],
        [-0.0231, -0.2149, -0.0985, -0.1822, -0.0118, -0.0538, -0.1029, -0.0757,
   

forward output:  tensor([[-0.0344, -0.1684, -0.0791, -0.1111, -0.0306, -0.0366, -0.0379, -0.0888,
         -0.1236],
        [-0.0381, -0.1987, -0.0601, -0.1725,  0.0141, -0.0599, -0.0989, -0.0649,
         -0.1432],
        [-0.0340, -0.1682, -0.0793, -0.1101, -0.0316, -0.0366, -0.0387, -0.0887,
         -0.1227],
        [-0.0950, -0.1891, -0.0905, -0.0712,  0.0021, -0.0124, -0.0464, -0.0645,
         -0.1230],
        [-0.0680, -0.1386, -0.0679, -0.0564,  0.0283, -0.0288, -0.0335, -0.0048,
         -0.0916],
        [-0.0338, -0.1683, -0.0787, -0.1109, -0.0311, -0.0369, -0.0380, -0.0881,
         -0.1231],
        [-0.1342, -0.1624, -0.0526, -0.0434, -0.0372, -0.0922, -0.0547, -0.0967,
         -0.1059],
        [-0.0554, -0.1330, -0.0581, -0.0963,  0.0260, -0.0381, -0.0317, -0.0208,
         -0.1436],
        [-0.0493, -0.2200, -0.0446, -0.1869, -0.0099,  0.0091, -0.0294, -0.0851,
         -0.1700],
        [-0.0892, -0.1100, -0.0896, -0.0554, -0.0120, -0.0473, -0.0047, -0.1015,
  

forward output:  tensor([[ 0.0121, -0.2096, -0.0635, -0.2045,  0.0330, -0.0660, -0.0625, -0.0693,
         -0.1744],
        [-0.1266, -0.1406, -0.0759, -0.1599, -0.0656, -0.0458, -0.0294, -0.0978,
         -0.0871],
        [-0.0066, -0.1754, -0.0626, -0.1868, -0.0447, -0.1013, -0.0437, -0.1373,
         -0.1177],
        [-0.0713, -0.2542, -0.1336, -0.1718,  0.0051, -0.0646, -0.0974, -0.0096,
         -0.1042],
        [-0.0363, -0.2297, -0.1120, -0.1955, -0.0261, -0.0661, -0.1148, -0.0770,
         -0.1387],
        [-0.0753, -0.1459, -0.0743, -0.0633,  0.0205, -0.0337, -0.0392, -0.0068,
         -0.0990],
        [-0.0458, -0.2046, -0.0690, -0.1795,  0.0078, -0.0649, -0.1053, -0.0640,
         -0.1496],
        [-0.0458, -0.2046, -0.0676, -0.1792,  0.0076, -0.0643, -0.1049, -0.0649,
         -0.1508],
        [-0.0708, -0.2491, -0.1342, -0.1689,  0.0055, -0.0656, -0.0984, -0.0069,
         -0.1015],
        [-0.0986, -0.1158, -0.0979, -0.0612, -0.0193, -0.0555, -0.0113, -0.1012,
  

forward output:  tensor([[-5.0293e-02, -2.5466e-01, -1.1287e-01, -1.6826e-01, -4.8861e-02,
         -4.6594e-02, -6.0222e-02, -1.1645e-01, -1.4957e-01],
        [-7.8982e-02, -2.6004e-01, -1.4132e-01, -1.7837e-01, -2.5346e-03,
         -7.1088e-02, -1.0433e-01, -8.6678e-03, -1.0889e-01],
        [-5.1236e-02, -2.1337e-01, -7.6953e-02, -1.8685e-01, -1.5438e-04,
         -7.1813e-02, -1.1395e-01, -6.4614e-02, -1.5549e-01],
        [-7.6964e-02, -2.6251e-01, -1.4126e-01, -1.8012e-01, -2.2776e-03,
         -7.3287e-02, -1.0362e-01, -9.3401e-03, -1.1004e-01],
        [-6.0459e-02, -2.3313e-01, -5.7532e-02, -1.9820e-01, -2.3718e-02,
         -1.6242e-03, -4.1363e-02, -8.6641e-02, -1.8303e-01],
        [-5.1426e-02, -2.1289e-01, -7.6065e-02, -1.8625e-01, -9.4628e-04,
         -7.1583e-02, -1.1348e-01, -6.4878e-02, -1.5507e-01],
        [-4.3408e-02, -2.3602e-01, -1.1944e-01, -2.0144e-01, -3.2632e-02,
         -7.1223e-02, -1.2168e-01, -7.6947e-02, -1.4514e-01],
        [-1.4891e-01, -1.7732e-

forward output:  tensor([[-0.0495, -0.2449, -0.1253, -0.2093, -0.0409, -0.0784, -0.1270, -0.0790,
         -0.1519],
        [-0.0890, -0.1614, -0.0868, -0.0767,  0.0049, -0.0443, -0.0517, -0.0107,
         -0.1125],
        [-0.0526, -0.1897, -0.0977, -0.1295, -0.0524, -0.0516, -0.0560, -0.0920,
         -0.1416],
        [-0.1115, -0.1323, -0.1106, -0.0762, -0.0359, -0.0668, -0.0250, -0.1055,
         -0.1472],
        [-0.1126, -0.1321, -0.1109, -0.0765, -0.0359, -0.0673, -0.0247, -0.1053,
         -0.1472],
        [-0.0493, -0.2446, -0.1248, -0.2095, -0.0406, -0.0786, -0.1265, -0.0792,
         -0.1525],
        [-0.0498, -0.2434, -0.1251, -0.2080, -0.0403, -0.0773, -0.1266, -0.0789,
         -0.1518],
        [-0.0501, -0.2448, -0.1256, -0.2096, -0.0409, -0.0787, -0.1272, -0.0798,
         -0.1522],
        [-0.0853, -0.2692, -0.1491, -0.1862, -0.0110, -0.0782, -0.1116, -0.0106,
         -0.1150],
        [-0.0491, -0.2455, -0.1248, -0.2094, -0.0416, -0.0788, -0.1265, -0.0797,
  

forward output:  tensor([[-0.0896, -0.1734, -0.0913, -0.1342, -0.0140, -0.0686, -0.0636, -0.0284,
         -0.1747],
        [-0.1301, -0.2306, -0.1259, -0.1087, -0.0372, -0.0419, -0.0761, -0.0705,
         -0.1571],
        [-0.0789, -0.2549, -0.0756, -0.2182, -0.0441, -0.0167, -0.0581, -0.0892,
         -0.2027],
        [-0.0325, -0.2034, -0.0879, -0.2134, -0.0742, -0.1232, -0.0680, -0.1422,
         -0.1413],
        [-0.0639, -0.2593, -0.1393, -0.2222, -0.0546, -0.0898, -0.1394, -0.0807,
         -0.1653],
        [-0.1263, -0.1478, -0.1242, -0.0915, -0.0523, -0.0797, -0.0381, -0.1082,
         -0.1604],
        [-0.0642, -0.2034, -0.1090, -0.1414, -0.0668, -0.0621, -0.0678, -0.0945,
         -0.1540],
        [-0.0796, -0.2548, -0.0759, -0.2185, -0.0447, -0.0158, -0.0584, -0.0907,
         -0.2033],
        [-0.0145, -0.2398, -0.0897, -0.2315,  0.0038, -0.0874, -0.0863, -0.0723,
         -0.2013],
        [-0.0713, -0.2376, -0.0968, -0.2084, -0.0241, -0.0899, -0.1327, -0.0691,
  

forward output:  tensor([[-0.0389, -0.2100, -0.0942, -0.2187, -0.0809, -0.1280, -0.0737, -0.1430,
         -0.1476],
        [-0.0850, -0.2611, -0.0821, -0.2240, -0.0512, -0.0218, -0.0645, -0.0908,
         -0.2104],
        [-0.0854, -0.2611, -0.0824, -0.2238, -0.0516, -0.0211, -0.0652, -0.0921,
         -0.2096],
        [-0.1338, -0.1551, -0.1313, -0.0979, -0.0595, -0.0851, -0.0447, -0.1093,
         -0.1680],
        [-0.0709, -0.2098, -0.1154, -0.1477, -0.0728, -0.0668, -0.0731, -0.0950,
         -0.1600],
        [-0.1335, -0.1561, -0.1317, -0.0988, -0.0595, -0.0861, -0.0449, -0.1091,
         -0.1681],
        [-0.1330, -0.1557, -0.1314, -0.0983, -0.0600, -0.0857, -0.0451, -0.1094,
         -0.1674],
        [-0.1066, -0.2938, -0.1704, -0.2080, -0.0348, -0.0959, -0.1299, -0.0137,
         -0.1353],
        [-0.1764, -0.2079, -0.0941, -0.0865, -0.0844, -0.1303, -0.0924, -0.1038,
         -0.1447],
        [-0.0777, -0.2890, -0.1408, -0.1985, -0.0806, -0.0722, -0.0840, -0.1213,
  

Labels:  tensor([[0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

forward output:  tensor([[-0.1263, -0.3193, -0.1925, -0.2314, -0.0586, -0.1155, -0.1524, -0.0180,
         -0.1558],
        [-0.1038, -0.2817, -0.1009, -0.2422, -0.0719, -0.0371, -0.0818, -0.0936,
         -0.2301],
        [-0.0975, -0.2706, -0.1239, -0.2380, -0.0553, -0.1140, -0.1581, -0.0743,
         -0.2023],
        [-0.0899, -0.2304, -0.1337, -0.1660, -0.0938, -0.0824, -0.0905, -0.0977,
         -0.1785],
        [-0.1985, -0.2306, -0.1159, -0.1076, -0.1094, -0.1489, -0.1123, -0.1073,
         -0.1618],
        [-0.0900, -0.2903, -0.1655, -0.2502, -0.0835, -0.1148, -0.1629, -0.0833,
         -0.1920],
        [-0.0980, -0.3126, -0.1628, -0.2199, -0.1033, -0.0909, -0.1041, -0.1230,
         -0.1949],
        [-0.0898, -0.2896, -0.1645, -0.2498, -0.0836, -0.1143, -0.1619, -0.0841,
         -0.1923],
        [-0.1162, -0.2049, -0.1185, -0.1618, -0.0459, -0.0939, -0.0888, -0.0345,
         -0.1985],
        [-0.0987, -0.3124, -0.1624, -0.2198, -0.1032, -0.0907, -0.1034, -0.1230,
  

forward output:  tensor([[-0.2052, -0.2376, -0.1232, -0.1139, -0.1177, -0.1557, -0.1183, -0.1089,
         -0.1676],
        [-0.1342, -0.3262, -0.1993, -0.2381, -0.0656, -0.1206, -0.1573, -0.0177,
         -0.1619],
        [-0.0961, -0.2980, -0.1723, -0.2567, -0.0913, -0.1216, -0.1691, -0.0834,
         -0.1982],
        [-0.1372, -0.2135, -0.1329, -0.1232, -0.0478, -0.0821, -0.0952, -0.0203,
         -0.1582],
        [-0.1345, -0.3296, -0.1990, -0.2402, -0.0659, -0.1197, -0.1581, -0.0196,
         -0.1638],
        [-0.2050, -0.2389, -0.1225, -0.1155, -0.1175, -0.1565, -0.1183, -0.1082,
         -0.1682],
        [-0.1653, -0.2721, -0.1607, -0.1449, -0.0757, -0.0720, -0.1053, -0.0761,
         -0.1897],
        [-0.0637, -0.2373, -0.1187, -0.2451, -0.1087, -0.1486, -0.0973, -0.1470,
         -0.1726],
        [-0.1240, -0.2119, -0.1255, -0.1685, -0.0526, -0.0986, -0.0946, -0.0357,
         -0.2058],
        [-0.0973, -0.2944, -0.1731, -0.2543, -0.0906, -0.1197, -0.1703, -0.0829,
  

forward output:  tensor([[-0.1721, -0.2794, -0.1679, -0.1516, -0.0839, -0.0781, -0.1117, -0.0766,
         -0.1964],
        [-0.1037, -0.3043, -0.1791, -0.2626, -0.0975, -0.1261, -0.1751, -0.0840,
         -0.2056],
        [-0.1696, -0.1950, -0.1657, -0.1346, -0.0983, -0.1182, -0.0774, -0.1139,
         -0.2015],
        [-0.1720, -0.2804, -0.1675, -0.1527, -0.0835, -0.0786, -0.1114, -0.0768,
         -0.1958],
        [-0.1008, -0.2456, -0.1460, -0.1797, -0.1071, -0.0921, -0.1021, -0.0994,
         -0.1921],
        [-0.1107, -0.2866, -0.1392, -0.2519, -0.0710, -0.1265, -0.1730, -0.0756,
         -0.2139],
        [-0.1718, -0.2799, -0.1680, -0.1516, -0.0837, -0.0777, -0.1120, -0.0769,
         -0.1966],
        [-0.1438, -0.2213, -0.1391, -0.1299, -0.0552, -0.0873, -0.1011, -0.0217,
         -0.1654],
        [-0.1027, -0.3055, -0.1786, -0.2633, -0.0979, -0.1267, -0.1746, -0.0840,
         -0.2050],
        [-0.2170, -0.2980, -0.1119, -0.2407, -0.1281, -0.1152, -0.0452, -0.1427,
  

forward output:  tensor([[-0.1469, -0.3437, -0.2132, -0.2533, -0.0815, -0.1336, -0.1703, -0.0203,
         -0.1762],
        [-0.1232, -0.3034, -0.1189, -0.2630, -0.0924, -0.0511, -0.0980, -0.0973,
         -0.2492],
        [-0.1463, -0.3429, -0.2136, -0.2531, -0.0817, -0.1342, -0.1712, -0.0204,
         -0.1764],
        [-0.1172, -0.2944, -0.1452, -0.2582, -0.0781, -0.1317, -0.1785, -0.0772,
         -0.2216],
        [-0.1479, -0.3456, -0.2140, -0.2544, -0.0819, -0.1332, -0.1711, -0.0213,
         -0.1762],
        [-0.1791, -0.2883, -0.1753, -0.1593, -0.0909, -0.0845, -0.1179, -0.0772,
         -0.2028],
        [-0.1760, -0.2022, -0.1718, -0.1415, -0.1069, -0.1240, -0.0841, -0.1156,
         -0.2072],
        [-0.1092, -0.3130, -0.1851, -0.2711, -0.1052, -0.1335, -0.1810, -0.0850,
         -0.2127],
        [-0.1187, -0.3374, -0.1837, -0.2421, -0.1261, -0.1089, -0.1217, -0.1259,
         -0.2142],
        [-0.1087, -0.3112, -0.1857, -0.2698, -0.1052, -0.1332, -0.1821, -0.0839,
  

forward output:  tensor([[-0.1167, -0.3190, -0.1926, -0.2760, -0.1122, -0.1389, -0.1877, -0.0852,
         -0.2183],
        [-0.1257, -0.3469, -0.1909, -0.2498, -0.1338, -0.1152, -0.1278, -0.1264,
         -0.2208],
        [-0.1440, -0.2350, -0.1451, -0.1903, -0.0752, -0.1169, -0.1130, -0.0391,
         -0.2247],
        [-0.1302, -0.3104, -0.1256, -0.2689, -0.0994, -0.0563, -0.1046, -0.0983,
         -0.2557],
        [-0.2255, -0.2591, -0.1431, -0.1342, -0.1403, -0.1736, -0.1368, -0.1119,
         -0.1866],
        [-0.1856, -0.2962, -0.1813, -0.1668, -0.0977, -0.0908, -0.1237, -0.0769,
         -0.2103],
        [-0.1139, -0.2582, -0.1585, -0.1908, -0.1199, -0.1018, -0.1138, -0.1011,
         -0.2041],
        [-0.2244, -0.2608, -0.1428, -0.1358, -0.1409, -0.1762, -0.1359, -0.1112,
         -0.1862],
        [-0.1131, -0.2590, -0.1579, -0.1913, -0.1199, -0.1037, -0.1128, -0.0995,
         -0.2035],
        [-0.2016, -0.2226, -0.1499, -0.2373, -0.1506, -0.1095, -0.1012, -0.1118,
  

forward output:  tensor([[-0.0762, -0.3054, -0.1511, -0.2912, -0.0607, -0.1386, -0.1401, -0.0747,
         -0.2580],
        [-0.1232, -0.3285, -0.1991, -0.2840, -0.1193, -0.1451, -0.1932, -0.0858,
         -0.2256],
        [-0.1303, -0.3117, -0.1599, -0.2740, -0.0936, -0.1450, -0.1915, -0.0785,
         -0.2346],
        [-0.2322, -0.2671, -0.1496, -0.1416, -0.1481, -0.1809, -0.1427, -0.1123,
         -0.1927],
        [-0.1201, -0.2654, -0.1649, -0.1974, -0.1263, -0.1077, -0.1190, -0.1008,
         -0.2103],
        [-0.1238, -0.3277, -0.1992, -0.2837, -0.1193, -0.1451, -0.1932, -0.0865,
         -0.2258],
        [-0.1202, -0.2652, -0.1647, -0.1968, -0.1268, -0.1082, -0.1188, -0.1008,
         -0.2099],
        [-0.2080, -0.2300, -0.1557, -0.2445, -0.1578, -0.1150, -0.1073, -0.1135,
         -0.1682],
        [-0.1348, -0.3167, -0.1310, -0.2748, -0.1054, -0.0619, -0.1102, -0.0985,
         -0.2627],
        [-0.1932, -0.3039, -0.1901, -0.1725, -0.1080, -0.0962, -0.1306, -0.0806,
  

forward output:  tensor([[-0.1575, -0.2503, -0.1581, -0.2043, -0.0909, -0.1287, -0.1257, -0.0424,
         -0.2371],
        [-0.1981, -0.2264, -0.1923, -0.1622, -0.1290, -0.1434, -0.1028, -0.1193,
         -0.2281],
        [-0.1372, -0.3196, -0.1660, -0.2802, -0.1011, -0.1510, -0.1975, -0.0794,
         -0.2412],
        [-0.1711, -0.2509, -0.1655, -0.1547, -0.0845, -0.1110, -0.1268, -0.0267,
         -0.1904],
        [-0.1997, -0.3129, -0.1961, -0.1817, -0.1141, -0.1034, -0.1368, -0.0806,
         -0.2230],
        [-0.0961, -0.2704, -0.1500, -0.2763, -0.1427, -0.1771, -0.1271, -0.1507,
         -0.2029],
        [-0.1979, -0.2268, -0.1922, -0.1628, -0.1295, -0.1427, -0.1035, -0.1202,
         -0.2290],
        [-0.1999, -0.3140, -0.1963, -0.1822, -0.1141, -0.1041, -0.1362, -0.0808,
         -0.2214],
        [-0.1966, -0.2265, -0.1919, -0.1629, -0.1293, -0.1429, -0.1037, -0.1185,
         -0.2279],
        [-0.1675, -0.3713, -0.2363, -0.2775, -0.1053, -0.1536, -0.1922, -0.0242,
  

forward output:  tensor([[-0.1319, -0.2790, -0.1761, -0.2085, -0.1403, -0.1191, -0.1307, -0.1022,
         -0.2220],
        [-0.1013, -0.2775, -0.1561, -0.2829, -0.1497, -0.1815, -0.1332, -0.1515,
         -0.2088],
        [-0.2215, -0.2431, -0.1680, -0.2561, -0.1718, -0.1261, -0.1188, -0.1149,
         -0.1800],
        [-0.1319, -0.2789, -0.1764, -0.2090, -0.1393, -0.1183, -0.1306, -0.1019,
         -0.2225],
        [-0.1466, -0.3704, -0.2130, -0.2712, -0.1571, -0.1346, -0.1484, -0.1294,
         -0.2403],
        [-0.1326, -0.2783, -0.1769, -0.2086, -0.1394, -0.1183, -0.1303, -0.1018,
         -0.2224],
        [-0.1461, -0.3710, -0.2141, -0.2717, -0.1570, -0.1346, -0.1490, -0.1287,
         -0.2401],
        [-0.1318, -0.2792, -0.1766, -0.2094, -0.1391, -0.1186, -0.1304, -0.1014,
         -0.2227],
        [-0.1633, -0.2601, -0.1642, -0.2130, -0.0987, -0.1360, -0.1320, -0.0431,
         -0.2438],
        [-0.2212, -0.2432, -0.1677, -0.2566, -0.1714, -0.1260, -0.1185, -0.1151,
  

forward output:  tensor([[-0.1507, -0.3346, -0.1814, -0.2940, -0.1161, -0.1632, -0.2127, -0.0801,
         -0.2536],
        [-0.0974, -0.3265, -0.1720, -0.3108, -0.0820, -0.1567, -0.1585, -0.0749,
         -0.2767],
        [-0.1854, -0.2653, -0.1771, -0.1680, -0.0990, -0.1215, -0.1373, -0.0296,
         -0.2039],
        [-0.2139, -0.3297, -0.2118, -0.1947, -0.1302, -0.1172, -0.1488, -0.0807,
         -0.2334],
        [-0.1825, -0.3869, -0.2495, -0.2920, -0.1197, -0.1646, -0.2042, -0.0245,
         -0.2103],
        [-0.2118, -0.2431, -0.2058, -0.1772, -0.1441, -0.1570, -0.1154, -0.1206,
         -0.2416],
        [-0.2134, -0.3297, -0.2116, -0.1954, -0.1301, -0.1169, -0.1497, -0.0818,
         -0.2350],
        [-0.1820, -0.3878, -0.2503, -0.2929, -0.1203, -0.1659, -0.2057, -0.0257,
         -0.2103],
        [-0.0975, -0.3270, -0.1722, -0.3109, -0.0825, -0.1565, -0.1585, -0.0750,
         -0.2769],
        [-0.1390, -0.2859, -0.1835, -0.2156, -0.1463, -0.1251, -0.1371, -0.1016,
  

forward output:  tensor([[-0.2591, -0.2961, -0.1769, -0.1684, -0.1776, -0.2064, -0.1671, -0.1152,
         -0.2175],
        [-0.1893, -0.3966, -0.2573, -0.3002, -0.1281, -0.1718, -0.2113, -0.0267,
         -0.2173],
        [-0.1564, -0.3430, -0.1884, -0.3012, -0.1237, -0.1705, -0.2198, -0.0802,
         -0.2595],
        [-0.1568, -0.3443, -0.1874, -0.3017, -0.1242, -0.1709, -0.2181, -0.0810,
         -0.2599],
        [-0.1569, -0.3426, -0.1893, -0.3011, -0.1234, -0.1702, -0.2203, -0.0801,
         -0.2592],
        [-0.1888, -0.3952, -0.2570, -0.2988, -0.1280, -0.1720, -0.2106, -0.0254,
         -0.2164],
        [-0.1142, -0.2899, -0.1685, -0.2943, -0.1620, -0.1928, -0.1439, -0.1519,
         -0.2215],
        [-0.1602, -0.3865, -0.2279, -0.2857, -0.1721, -0.1470, -0.1616, -0.1306,
         -0.2538],
        [-0.2195, -0.2507, -0.2128, -0.1846, -0.1518, -0.1640, -0.1220, -0.1214,
         -0.2485],
        [-0.1569, -0.3417, -0.1882, -0.3009, -0.1231, -0.1696, -0.2191, -0.0804,
  

forward output:  tensor([[-0.1105, -0.3424, -0.1858, -0.3254, -0.0966, -0.1685, -0.1707, -0.0757,
         -0.2900],
        [-0.1576, -0.3660, -0.2335, -0.3179, -0.1554, -0.1771, -0.2243, -0.0884,
         -0.2590],
        [-0.1988, -0.2806, -0.1897, -0.1816, -0.1133, -0.1334, -0.1496, -0.0316,
         -0.2171],
        [-0.1574, -0.3664, -0.2338, -0.3186, -0.1559, -0.1781, -0.2249, -0.0883,
         -0.2591],
        [-0.1964, -0.4053, -0.2648, -0.3079, -0.1360, -0.1784, -0.2178, -0.0272,
         -0.2231],
        [-0.2270, -0.3471, -0.2255, -0.2111, -0.1449, -0.1302, -0.1621, -0.0835,
         -0.2484],
        [-0.2252, -0.2583, -0.2190, -0.1915, -0.1594, -0.1707, -0.1291, -0.1217,
         -0.2546],
        [-0.2658, -0.3016, -0.1833, -0.1734, -0.1852, -0.2118, -0.1726, -0.1166,
         -0.2235],
        [-0.1676, -0.3954, -0.2350, -0.2937, -0.1794, -0.1531, -0.1674, -0.1311,
         -0.2609],
        [-0.1985, -0.2813, -0.1897, -0.1820, -0.1133, -0.1333, -0.1492, -0.0321,
  

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [1., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

forward output:  tensor([[-0.1830, -0.3761, -0.2170, -0.3309, -0.1530, -0.1965, -0.2459, -0.0830,
         -0.2868],
        [-0.2473, -0.2829, -0.2397, -0.2128, -0.1810, -0.1919, -0.1470, -0.1237,
         -0.2750],
        [-0.1689, -0.3212, -0.2125, -0.2477, -0.1802, -0.1524, -0.1656, -0.1062,
         -0.2603],
        [-0.2478, -0.2821, -0.2395, -0.2119, -0.1816, -0.1906, -0.1473, -0.1255,
         -0.2751],
        [-0.2182, -0.3053, -0.2084, -0.2036, -0.1343, -0.1520, -0.1670, -0.0345,
         -0.2378],
        [-0.2185, -0.3048, -0.2087, -0.2032, -0.1345, -0.1515, -0.1678, -0.0352,
         -0.2383],
        [-0.1306, -0.3653, -0.2064, -0.3462, -0.1177, -0.1865, -0.1882, -0.0769,
         -0.3091],
        [-0.1783, -0.3884, -0.2553, -0.3390, -0.1778, -0.1974, -0.2447, -0.0899,
         -0.2791],
        [-0.2057, -0.3039, -0.2042, -0.2540, -0.1428, -0.1728, -0.1682, -0.0493,
         -0.2815],
        [-0.1688, -0.3209, -0.2137, -0.2476, -0.1808, -0.1528, -0.1669, -0.1050,
  

forward output:  tensor([[-0.1954, -0.4271, -0.2644, -0.3226, -0.2094, -0.1796, -0.1937, -0.1331,
         -0.2881],
        [-0.2548, -0.3793, -0.2550, -0.2389, -0.1754, -0.1559, -0.1870, -0.0854,
         -0.2736],
        [-0.1851, -0.3956, -0.2626, -0.3455, -0.1844, -0.2029, -0.2516, -0.0891,
         -0.2859],
        [-0.1473, -0.3232, -0.1991, -0.3254, -0.1951, -0.2196, -0.1720, -0.1553,
         -0.2548],
        [-0.2252, -0.3134, -0.2150, -0.2103, -0.1418, -0.1575, -0.1736, -0.0371,
         -0.2457],
        [-0.2549, -0.2900, -0.2464, -0.2196, -0.1892, -0.1985, -0.1540, -0.1261,
         -0.2818],
        [-0.2249, -0.4377, -0.2938, -0.3373, -0.1667, -0.2026, -0.2455, -0.0308,
         -0.2506],
        [-0.2240, -0.4358, -0.2934, -0.3365, -0.1656, -0.2035, -0.2451, -0.0296,
         -0.2510],
        [-0.1859, -0.3989, -0.2619, -0.3476, -0.1850, -0.2034, -0.2486, -0.0914,
         -0.2863],
        [-0.2546, -0.2899, -0.2455, -0.2211, -0.1894, -0.1961, -0.1537, -0.1263,
  

forward output:  tensor([[-0.2796, -0.3075, -0.2211, -0.3180, -0.2341, -0.1770, -0.1698, -0.1251,
         -0.2391],
        [-0.2093, -0.4470, -0.2794, -0.3397, -0.2250, -0.1930, -0.2059, -0.1355,
         -0.3023],
        [-0.2257, -0.3285, -0.2243, -0.2771, -0.1652, -0.1925, -0.1865, -0.0526,
         -0.3011],
        [-0.2036, -0.4028, -0.2384, -0.3540, -0.1766, -0.2154, -0.2670, -0.0864,
         -0.3076],
        [-0.2387, -0.4533, -0.3077, -0.3526, -0.1809, -0.2157, -0.2582, -0.0312,
         -0.2654],
        [-0.2047, -0.3914, -0.1988, -0.3447, -0.1802, -0.1195, -0.1744, -0.1071,
         -0.3368],
        [-0.2690, -0.3978, -0.2698, -0.2543, -0.1904, -0.1698, -0.1996, -0.0865,
         -0.2859],
        [-0.2069, -0.3939, -0.2000, -0.3465, -0.1800, -0.1182, -0.1732, -0.1074,
         -0.3361],
        [-0.2701, -0.3054, -0.2591, -0.2353, -0.2044, -0.2109, -0.1660, -0.1280,
         -0.2962],
        [-0.1998, -0.4131, -0.2770, -0.3608, -0.1990, -0.2167, -0.2636, -0.0912,
  

forward output:  tensor([[-0.2123, -0.4005, -0.2065, -0.3523, -0.1865, -0.1246, -0.1793, -0.1065,
         -0.3440],
        [-0.2763, -0.4056, -0.2771, -0.2624, -0.1979, -0.1768, -0.2064, -0.0874,
         -0.2935],
        [-0.1936, -0.3495, -0.2396, -0.2744, -0.2088, -0.1766, -0.1919, -0.1070,
         -0.2829],
        [-0.2452, -0.3376, -0.2338, -0.2316, -0.1625, -0.1765, -0.1908, -0.0397,
         -0.2663],
        [-0.2068, -0.4222, -0.2833, -0.3690, -0.2066, -0.2225, -0.2682, -0.0926,
         -0.3072],
        [-0.3130, -0.3511, -0.2297, -0.2211, -0.2361, -0.2567, -0.2149, -0.1221,
         -0.2670],
        [-0.1593, -0.3966, -0.2350, -0.3747, -0.1468, -0.2120, -0.2127, -0.0790,
         -0.3356],
        [-0.1672, -0.3431, -0.2171, -0.3459, -0.2154, -0.2373, -0.1897, -0.1570,
         -0.2732],
        [-0.2097, -0.4089, -0.2460, -0.3608, -0.1829, -0.2215, -0.2742, -0.0862,
         -0.3147],
        [-0.2444, -0.4651, -0.3156, -0.3630, -0.1899, -0.2239, -0.2670, -0.0329,
  

forward output:  tensor([[-0.2521, -0.3453, -0.2393, -0.2390, -0.1691, -0.1824, -0.1952, -0.0404,
         -0.2736],
        [-0.2536, -0.3446, -0.2394, -0.2381, -0.1702, -0.1817, -0.1960, -0.0430,
         -0.2738],
        [-0.2164, -0.4186, -0.2526, -0.3696, -0.1909, -0.2283, -0.2804, -0.0879,
         -0.3224],
        [-0.2233, -0.4610, -0.2935, -0.3533, -0.2396, -0.2058, -0.2195, -0.1362,
         -0.3160],
        [-0.2831, -0.4127, -0.2836, -0.2686, -0.2056, -0.1823, -0.2124, -0.0884,
         -0.3011],
        [-0.2117, -0.4311, -0.2901, -0.3773, -0.2135, -0.2300, -0.2758, -0.0912,
         -0.3147],
        [-0.2846, -0.3238, -0.2732, -0.2519, -0.2198, -0.2256, -0.1788, -0.1311,
         -0.3113],
        [-0.2931, -0.3214, -0.2335, -0.3309, -0.2475, -0.1884, -0.1810, -0.1265,
         -0.2526],
        [-0.2520, -0.3451, -0.2396, -0.2389, -0.1691, -0.1819, -0.1952, -0.0403,
         -0.2731],
        [-0.2827, -0.3247, -0.2721, -0.2513, -0.2200, -0.2239, -0.1790, -0.1329,
  

forward output:  tensor([[-0.2310, -0.4382, -0.2669, -0.3850, -0.2070, -0.2417, -0.2931, -0.0890,
         -0.3353],
        [-0.2662, -0.4868, -0.3374, -0.3835, -0.2114, -0.2425, -0.2858, -0.0342,
         -0.2933],
        [-0.3061, -0.3355, -0.2450, -0.3446, -0.2609, -0.1994, -0.1918, -0.1291,
         -0.2658],
        [-0.2313, -0.4371, -0.2661, -0.3848, -0.2061, -0.2407, -0.2923, -0.0895,
         -0.3363],
        [-0.3333, -0.3716, -0.2490, -0.2409, -0.2574, -0.2750, -0.2324, -0.1250,
         -0.2860],
        [-0.2374, -0.4792, -0.3087, -0.3694, -0.2551, -0.2189, -0.2320, -0.1385,
         -0.3300],
        [-0.2282, -0.4457, -0.3055, -0.3913, -0.2285, -0.2429, -0.2886, -0.0942,
         -0.3286],
        [-0.2978, -0.4309, -0.2997, -0.2845, -0.2219, -0.1968, -0.2249, -0.0917,
         -0.3123],
        [-0.2275, -0.4439, -0.3052, -0.3897, -0.2281, -0.2421, -0.2891, -0.0934,
         -0.3280],
        [-0.2681, -0.4882, -0.3375, -0.3838, -0.2117, -0.2406, -0.2850, -0.0341,
  

forward output:  tensor([[-0.2734, -0.3683, -0.2587, -0.2591, -0.1905, -0.2002, -0.2135, -0.0456,
         -0.2950],
        [-0.2194, -0.3783, -0.2648, -0.3020, -0.2360, -0.2004, -0.2157, -0.1103,
         -0.3091],
        [-0.2603, -0.3688, -0.2579, -0.3144, -0.2022, -0.2252, -0.2171, -0.0587,
         -0.3340],
        [-0.2401, -0.4307, -0.2306, -0.3817, -0.2128, -0.1447, -0.2028, -0.1114,
         -0.3722],
        [-0.2453, -0.4858, -0.3163, -0.3754, -0.2625, -0.2259, -0.2393, -0.1390,
         -0.3364],
        [-0.3398, -0.3793, -0.2556, -0.2478, -0.2648, -0.2814, -0.2387, -0.1259,
         -0.2927],
        [-0.2193, -0.3776, -0.2651, -0.3013, -0.2354, -0.1997, -0.2156, -0.1092,
         -0.3079],
        [-0.3043, -0.4392, -0.3064, -0.2914, -0.2282, -0.2031, -0.2317, -0.0902,
         -0.3206],
        [-0.2726, -0.3689, -0.2590, -0.2596, -0.1902, -0.2015, -0.2139, -0.0441,
         -0.2938],
        [-0.2351, -0.4555, -0.3121, -0.3992, -0.2365, -0.2498, -0.2938, -0.0954,
  

Labels:  tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 

forward output:  tensor([[-0.3335, -0.4741, -0.3364, -0.3223, -0.2589, -0.2312, -0.2573, -0.0937,
         -0.3488],
        [-0.2468, -0.4075, -0.2908, -0.3285, -0.2627, -0.2250, -0.2396, -0.1120,
         -0.3339],
        [-0.2657, -0.4884, -0.3429, -0.4287, -0.2665, -0.2763, -0.3201, -0.0979,
         -0.3634],
        [-0.2170, -0.4615, -0.2927, -0.4334, -0.2049, -0.2633, -0.2614, -0.0842,
         -0.3894],
        [-0.3000, -0.4024, -0.2832, -0.2885, -0.2177, -0.2263, -0.2355, -0.0495,
         -0.3241],
        [-0.2663, -0.4599, -0.2586, -0.4081, -0.2407, -0.1680, -0.2273, -0.1128,
         -0.4004],
        [-0.2661, -0.4804, -0.3043, -0.4233, -0.2436, -0.2752, -0.3269, -0.0914,
         -0.3713],
        [-0.2647, -0.4862, -0.3420, -0.4280, -0.2659, -0.2761, -0.3208, -0.0975,
         -0.3641],
        [-0.2662, -0.4801, -0.3037, -0.4225, -0.2430, -0.2745, -0.3261, -0.0907,
         -0.3709],
        [-0.2670, -0.4799, -0.3036, -0.4222, -0.2430, -0.2735, -0.3260, -0.0912,
  

forward output:  tensor([[-0.2815, -0.4959, -0.3194, -0.4363, -0.2575, -0.2871, -0.3401, -0.0907,
         -0.3838],
        [-0.3511, -0.3988, -0.3331, -0.3201, -0.2871, -0.2866, -0.2341, -0.1430,
         -0.3777],
        [-0.2811, -0.4989, -0.3194, -0.4389, -0.2584, -0.2879, -0.3407, -0.0923,
         -0.3858],
        [-0.2801, -0.5053, -0.3572, -0.4444, -0.2810, -0.2896, -0.3333, -0.0987,
         -0.3792],
        [-0.3496, -0.4930, -0.3523, -0.3388, -0.2746, -0.2456, -0.2704, -0.0953,
         -0.3617],
        [-0.2810, -0.5006, -0.3188, -0.4389, -0.2594, -0.2889, -0.3393, -0.0918,
         -0.3845],
        [-0.3537, -0.3852, -0.2876, -0.3921, -0.3082, -0.2397, -0.2306, -0.1362,
         -0.3152],
        [-0.3535, -0.3853, -0.2869, -0.3924, -0.3090, -0.2404, -0.2310, -0.1369,
         -0.3140],
        [-0.3172, -0.5495, -0.3893, -0.4409, -0.2651, -0.2890, -0.3349, -0.0392,
         -0.3450],
        [-0.3182, -0.5491, -0.3898, -0.4397, -0.2650, -0.2882, -0.3339, -0.0384,
  

forward output:  tensor([[-0.3327, -0.5678, -0.4059, -0.4564, -0.2805, -0.3025, -0.3480, -0.0389,
         -0.3580],
        [-0.3670, -0.4146, -0.3466, -0.3349, -0.3018, -0.3014, -0.2461, -0.1442,
         -0.3917],
        [-0.2955, -0.5235, -0.3728, -0.4604, -0.2966, -0.3033, -0.3464, -0.1000,
         -0.3938],
        [-0.3169, -0.4358, -0.3129, -0.3762, -0.2627, -0.2794, -0.2672, -0.0683,
         -0.3882],
        [-0.2943, -0.5219, -0.3727, -0.4593, -0.2965, -0.3038, -0.3476, -0.0987,
         -0.3934],
        [-0.3640, -0.5107, -0.3673, -0.3529, -0.2904, -0.2589, -0.2835, -0.0966,
         -0.3760],
        [-0.2933, -0.4906, -0.2867, -0.4367, -0.2697, -0.1910, -0.2552, -0.1153,
         -0.4267],
        [-0.2750, -0.4394, -0.3178, -0.3584, -0.2922, -0.2507, -0.2650, -0.1159,
         -0.3612],
        [-0.2954, -0.5262, -0.3738, -0.4630, -0.2976, -0.3063, -0.3481, -0.0997,
         -0.3934],
        [-0.3357, -0.5802, -0.3419, -0.4801, -0.3147, -0.2666, -0.3084, -0.1025,
  

forward output:  tensor([[-0.3127, -0.5406, -0.3910, -0.4761, -0.3132, -0.3189, -0.3622, -0.1008,
         -0.4085],
        [-0.2599, -0.5124, -0.3370, -0.4792, -0.2490, -0.3026, -0.3003, -0.0871,
         -0.4333],
        [-0.3823, -0.4344, -0.3611, -0.3527, -0.3176, -0.3165, -0.2593, -0.1471,
         -0.4081],
        [-0.3481, -0.5879, -0.4220, -0.4744, -0.2973, -0.3166, -0.3631, -0.0413,
         -0.3733],
        [-0.4124, -0.4597, -0.3236, -0.3204, -0.3379, -0.3460, -0.2994, -0.1343,
         -0.3650],
        [-0.3478, -0.5848, -0.4212, -0.4722, -0.2963, -0.3154, -0.3627, -0.0404,
         -0.3734],
        [-0.3450, -0.4520, -0.3226, -0.3329, -0.2626, -0.2660, -0.2737, -0.0601,
         -0.3684],
        [-0.3452, -0.4516, -0.3234, -0.3341, -0.2629, -0.2664, -0.2741, -0.0591,
         -0.3682],
        [-0.3476, -0.5865, -0.4211, -0.4735, -0.2968, -0.3168, -0.3623, -0.0400,
         -0.3737],
        [-0.3481, -0.5882, -0.4218, -0.4745, -0.2971, -0.3166, -0.3622, -0.0405,
  

forward output:  tensor([[-0.3634, -0.6068, -0.4376, -0.4914, -0.3126, -0.3301, -0.3759, -0.0412,
         -0.3882],
        [-0.3220, -0.5230, -0.3145, -0.4684, -0.2980, -0.2139, -0.2808, -0.1178,
         -0.4539],
        [-0.3812, -0.4665, -0.3400, -0.3896, -0.2797, -0.2880, -0.2858, -0.0192,
         -0.4117],
        [-0.3470, -0.4679, -0.3411, -0.4066, -0.2929, -0.3078, -0.2916, -0.0717,
         -0.4148],
        [-0.3952, -0.5478, -0.3990, -0.3857, -0.3217, -0.2876, -0.3103, -0.0997,
         -0.4050],
        [-0.3595, -0.4668, -0.3351, -0.3481, -0.2761, -0.2785, -0.2836, -0.0602,
         -0.3821],
        [-0.3035, -0.4703, -0.3455, -0.3865, -0.3207, -0.2757, -0.2901, -0.1183,
         -0.3879],
        [-0.3597, -0.4678, -0.3354, -0.3487, -0.2767, -0.2800, -0.2848, -0.0605,
         -0.3817],
        [-0.3252, -0.5547, -0.3660, -0.4881, -0.3055, -0.3297, -0.3831, -0.0952,
         -0.4296],
        [-0.3950, -0.5459, -0.3991, -0.3837, -0.3220, -0.2869, -0.3105, -0.0991,
  

forward output:  tensor([[-0.4117, -0.5690, -0.4171, -0.4034, -0.3381, -0.3023, -0.3248, -0.1023,
         -0.4185],
        [-0.4136, -0.4697, -0.3890, -0.3834, -0.3479, -0.3453, -0.2845, -0.1530,
         -0.4390],
        [-0.3412, -0.5723, -0.3820, -0.5032, -0.3207, -0.3429, -0.3969, -0.0953,
         -0.4440],
        [-0.3409, -0.5715, -0.3819, -0.5032, -0.3201, -0.3423, -0.3972, -0.0958,
         -0.4446],
        [-0.4133, -0.4688, -0.3888, -0.3827, -0.3471, -0.3441, -0.2835, -0.1523,
         -0.4389],
        [-0.2935, -0.4666, -0.3307, -0.4668, -0.3363, -0.3433, -0.2952, -0.1712,
         -0.3978],
        [-0.3440, -0.5767, -0.4218, -0.5094, -0.3449, -0.3469, -0.3889, -0.1047,
         -0.4393],
        [-0.3795, -0.6225, -0.4537, -0.5065, -0.3277, -0.3439, -0.3905, -0.0421,
         -0.4048],
        [-0.3626, -0.4840, -0.3550, -0.4215, -0.3074, -0.3193, -0.3044, -0.0750,
         -0.4304],
        [-0.4108, -0.5654, -0.4149, -0.4013, -0.3377, -0.3019, -0.3240, -0.1009,
  

forward output:  tensor([[-0.3593, -0.5959, -0.4374, -0.5263, -0.3601, -0.3606, -0.4020, -0.1047,
         -0.4552],
        [-0.3335, -0.5016, -0.3724, -0.4164, -0.3494, -0.3014, -0.3142, -0.1217,
         -0.4170],
        [-0.3339, -0.5021, -0.3737, -0.4160, -0.3499, -0.3024, -0.3156, -0.1210,
         -0.4156],
        [-0.4247, -0.4611, -0.3493, -0.4658, -0.3796, -0.3013, -0.2883, -0.1486,
         -0.3871],
        [-0.4267, -0.5833, -0.4316, -0.4160, -0.3540, -0.3160, -0.3377, -0.1014,
         -0.4344],
        [-0.3333, -0.5058, -0.3748, -0.4196, -0.3525, -0.3052, -0.3183, -0.1221,
         -0.4159],
        [-0.3558, -0.5944, -0.3996, -0.5221, -0.3378, -0.3581, -0.4150, -0.0972,
         -0.4595],
        [-0.3078, -0.5639, -0.3842, -0.5260, -0.2941, -0.3426, -0.3398, -0.0909,
         -0.4781],
        [-0.4278, -0.5860, -0.4322, -0.4188, -0.3544, -0.3170, -0.3380, -0.1036,
         -0.4343],
        [-0.3322, -0.5033, -0.3721, -0.4171, -0.3505, -0.3015, -0.3155, -0.1222,
  

Labels:  tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

Labels:  tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 

forward output:  tensor([[-0.4111, -0.6551, -0.4891, -0.5801, -0.4103, -0.4080, -0.4459, -0.1075,
         -0.5049],
        [-0.3791, -0.5517, -0.4157, -0.4623, -0.3949, -0.3442, -0.3537, -0.1235,
         -0.4593],
        [-0.4095, -0.6540, -0.4883, -0.5792, -0.4106, -0.4078, -0.4454, -0.1078,
         -0.5043],
        [-0.4226, -0.6943, -0.4989, -0.5604, -0.4411, -0.3873, -0.3920, -0.1547,
         -0.5087],
        [-0.3983, -0.6046, -0.3902, -0.5450, -0.3724, -0.2762, -0.3468, -0.1213,
         -0.5299],
        [-0.4097, -0.6537, -0.4887, -0.5798, -0.4101, -0.4091, -0.4480, -0.1064,
         -0.5055],
        [-0.4043, -0.6552, -0.4477, -0.5763, -0.3872, -0.4042, -0.4566, -0.0983,
         -0.5101],
        [-0.4192, -0.6890, -0.4955, -0.5578, -0.4390, -0.3840, -0.3921, -0.1534,
         -0.5088],
        [-0.4104, -0.6539, -0.4885, -0.5792, -0.4104, -0.4074, -0.4453, -0.1082,
         -0.5052],
        [-0.4201, -0.6869, -0.4956, -0.5560, -0.4383, -0.3833, -0.3925, -0.1531,
  

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 

Labels:  tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

Labels:  tensor([[0., 1., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 

forward output:  tensor([[-0.5933, -0.6714, -0.5487, -0.5685, -0.5241, -0.5157, -0.4272, -0.1787,
         -0.6192],
        [-0.5493, -0.6814, -0.5063, -0.5491, -0.4666, -0.4598, -0.4433, -0.0908,
         -0.5709],
        [-0.5577, -0.8309, -0.6369, -0.6945, -0.5068, -0.5047, -0.5524, -0.0495,
         -0.5855],
        [-0.4881, -0.6716, -0.5185, -0.5725, -0.5033, -0.4470, -0.4471, -0.1346,
         -0.5637],
        [-0.5066, -0.7230, -0.4959, -0.6554, -0.4775, -0.3682, -0.4398, -0.1279,
         -0.6343],
        [-0.5474, -0.6777, -0.5037, -0.5431, -0.4625, -0.4564, -0.4391, -0.0896,
         -0.5689],
        [-0.6180, -0.6716, -0.5088, -0.5119, -0.5356, -0.5222, -0.4653, -0.1572,
         -0.5627],
        [-0.6181, -0.6697, -0.5075, -0.5114, -0.5349, -0.5218, -0.4646, -0.1570,
         -0.5619],
        [-0.5583, -0.8339, -0.6398, -0.6979, -0.5093, -0.5074, -0.5563, -0.0505,
         -0.5860],
        [-0.5269, -0.7847, -0.6061, -0.6983, -0.5253, -0.5136, -0.5465, -0.1146,
  

forward output:  tensor([[-0.6034, -0.7883, -0.6084, -0.5933, -0.5279, -0.4801, -0.4875, -0.1148,
         -0.6015],
        [-0.5675, -0.8463, -0.6489, -0.7079, -0.5178, -0.5152, -0.5630, -0.0510,
         -0.5967],
        [-0.6025, -0.6773, -0.5548, -0.5748, -0.5311, -0.5213, -0.4324, -0.1804,
         -0.6280],
        [-0.5518, -0.8425, -0.6334, -0.6931, -0.5712, -0.5065, -0.5070, -0.1656,
         -0.6389],
        [-0.5669, -0.8394, -0.6458, -0.7032, -0.5150, -0.5125, -0.5604, -0.0498,
         -0.5946],
        [-0.5367, -0.7929, -0.6169, -0.7056, -0.5357, -0.5231, -0.5569, -0.1151,
         -0.6281],
        [-0.6018, -0.6804, -0.5566, -0.5757, -0.5323, -0.5228, -0.4340, -0.1811,
         -0.6282],
        [-0.5280, -0.7969, -0.5752, -0.7008, -0.5084, -0.5177, -0.5699, -0.1018,
         -0.6326],
        [-0.6031, -0.6791, -0.5562, -0.5770, -0.5330, -0.5217, -0.4336, -0.1808,
         -0.6284],
        [-0.5670, -0.8446, -0.6482, -0.7064, -0.5172, -0.5144, -0.5623, -0.0506,
  

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 

forward output:  tensor([[-0.5827, -0.7324, -0.5658, -0.6514, -0.5318, -0.5309, -0.4898, -0.1083,
         -0.6398],
        [-0.5682, -0.8389, -0.6176, -0.7388, -0.5475, -0.5548, -0.6083, -0.1063,
         -0.6701],
        [-0.5678, -0.8462, -0.6158, -0.7434, -0.5494, -0.5578, -0.6052, -0.1068,
         -0.6730],
        [-0.6433, -0.8352, -0.6501, -0.6346, -0.5668, -0.5184, -0.5232, -0.1192,
         -0.6385],
        [-0.5143, -0.8019, -0.5968, -0.7424, -0.5035, -0.5346, -0.5186, -0.1079,
         -0.6808],
        [-0.5344, -0.7235, -0.5620, -0.6195, -0.5493, -0.4914, -0.4858, -0.1428,
         -0.6074],
        [-0.5360, -0.7253, -0.5663, -0.6206, -0.5524, -0.4909, -0.4905, -0.1430,
         -0.6064],
        [-0.5675, -0.8413, -0.6164, -0.7408, -0.5484, -0.5548, -0.6076, -0.1074,
         -0.6724],
        [-0.6645, -0.7212, -0.5526, -0.5569, -0.5815, -0.5632, -0.5039, -0.1660,
         -0.6098],
        [-0.6393, -0.7253, -0.5931, -0.6166, -0.5714, -0.5605, -0.4670, -0.1874,
  

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

Labels:  tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 

forward output:  tensor([[-0.7406, -0.8446, -0.6933, -0.7248, -0.6779, -0.6645, -0.5579, -0.2037,
         -0.7639],
        [-0.6959, -0.8579, -0.6539, -0.7078, -0.6227, -0.6126, -0.5821, -0.1186,
         -0.7153],
        [-0.6951, -0.8504, -0.6498, -0.7019, -0.6179, -0.6070, -0.5750, -0.1160,
         -0.7127],
        [-0.7637, -0.8236, -0.6443, -0.6482, -0.6753, -0.6493, -0.5862, -0.1801,
         -0.7064],
        [-0.6547, -0.8856, -0.6408, -0.8101, -0.6255, -0.5001, -0.5711, -0.1533,
         -0.7686],
        [-0.7492, -0.9543, -0.7588, -0.7388, -0.6721, -0.6164, -0.6174, -0.1342,
         -0.7393],
        [-0.6325, -0.8365, -0.6591, -0.7183, -0.6508, -0.5880, -0.5738, -0.1602,
         -0.6990],
        [-0.6956, -1.0015, -0.7803, -0.8332, -0.7118, -0.6374, -0.6345, -0.1783,
         -0.7814],
        [-0.6204, -0.9121, -0.7031, -0.8467, -0.6074, -0.6288, -0.6125, -0.1231,
         -0.7803],
        [-0.7597, -0.8172, -0.6380, -0.6432, -0.6709, -0.6450, -0.5820, -0.1783,
  

forward output:  tensor([[-0.7603, -0.8618, -0.7103, -0.7411, -0.6956, -0.6806, -0.5730, -0.2048,
         -0.7820],
        [-0.7643, -0.8692, -0.7159, -0.7483, -0.7016, -0.6875, -0.5783, -0.2069,
         -0.7857],
        [-0.7800, -0.8410, -0.6582, -0.6649, -0.6921, -0.6639, -0.6001, -0.1821,
         -0.7274],
        [-0.7194, -1.0324, -0.8067, -0.8600, -0.7385, -0.6614, -0.6584, -0.1836,
         -0.8076],
        [-0.7215, -1.0292, -0.8078, -0.8569, -0.7375, -0.6605, -0.6576, -0.1823,
         -0.8045],
        [-0.6535, -0.8617, -0.6805, -0.7411, -0.6723, -0.6091, -0.5937, -0.1622,
         -0.7191],
        [-0.7078, -0.9813, -0.7868, -0.8776, -0.7028, -0.6769, -0.7041, -0.1375,
         -0.7921],
        [-0.7185, -1.0287, -0.8055, -0.8562, -0.7350, -0.6591, -0.6570, -0.1803,
         -0.8050],
        [-0.7317, -0.8143, -0.6567, -0.7188, -0.6158, -0.5880, -0.5759, -0.0645,
         -0.7607],
        [-0.7032, -1.0033, -0.7557, -0.8783, -0.6869, -0.6856, -0.7296, -0.1256,
  

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 

forward output:  tensor([[-0.6627, -0.8410, -0.6523, -0.8320, -0.6858, -0.6602, -0.5889, -0.2270,
         -0.7865],
        [-0.8366, -1.0523, -0.8447, -0.8250, -0.7541, -0.6948, -0.6925, -0.1444,
         -0.8197],
        [-0.8366, -0.9035, -0.7106, -0.7200, -0.7462, -0.7174, -0.6471, -0.1893,
         -0.7817],
        [-0.7575, -0.9319, -0.7377, -0.8328, -0.7096, -0.6982, -0.6417, -0.1384,
         -0.8136],
        [-0.7636, -0.9470, -0.7501, -0.8484, -0.7222, -0.7127, -0.6548, -0.1406,
         -0.8155],
        [-0.7734, -1.0493, -0.8508, -0.9413, -0.7638, -0.7343, -0.7602, -0.1444,
         -0.8518],
        [-0.8306, -1.0451, -0.8401, -0.8147, -0.7510, -0.6896, -0.6880, -0.1414,
         -0.8156],
        [-0.8372, -0.9026, -0.7107, -0.7185, -0.7461, -0.7162, -0.6463, -0.1901,
         -0.7816],
        [-0.7457, -0.9848, -0.7250, -0.9086, -0.7147, -0.5798, -0.6498, -0.1756,
         -0.8406],
        [-0.6656, -0.8462, -0.6561, -0.8370, -0.6889, -0.6636, -0.5919, -0.2281,
  

forward output:  tensor([[-0.7960, -1.1165, -0.8870, -0.9333, -0.8137, -0.7323, -0.7282, -0.1896,
         -0.8822],
        [-0.7432, -0.9789, -0.7286, -0.8984, -0.7144, -0.5798, -0.6492, -0.1648,
         -0.8503],
        [-0.6757, -0.8525, -0.6644, -0.8450, -0.6988, -0.6704, -0.6019, -0.2298,
         -0.8010],
        [-0.7220, -0.9333, -0.7450, -0.8072, -0.7381, -0.6737, -0.6508, -0.1692,
         -0.7827],
        [-0.7724, -1.0736, -0.8294, -0.9453, -0.7524, -0.7515, -0.7968, -0.1264,
         -0.8758],
        [-0.8121, -1.1212, -0.8963, -0.9578, -0.7552, -0.7320, -0.7837, -0.0787,
         -0.8381],
        [-0.7620, -1.0028, -0.7389, -0.9274, -0.7278, -0.5953, -0.6620, -0.1770,
         -0.8519],
        [-0.8424, -1.0558, -0.8506, -0.8265, -0.7625, -0.7005, -0.6987, -0.1426,
         -0.8292],
        [-0.7827, -1.0924, -0.8356, -0.9596, -0.7647, -0.7606, -0.8006, -0.1329,
         -0.8894],
        [-0.7809, -1.0592, -0.8581, -0.9489, -0.7733, -0.7414, -0.7664, -0.1454,
  

forward output:  tensor([[-0.7371, -0.9564, -0.7617, -0.8289, -0.7565, -0.6929, -0.6674, -0.1724,
         -0.7981],
        [-0.7319, -1.0321, -0.8137, -0.9574, -0.7183, -0.7273, -0.7072, -0.1386,
         -0.8838],
        [-0.8591, -1.0788, -0.8680, -0.8475, -0.7785, -0.7171, -0.7137, -0.1476,
         -0.8429],
        [-0.7948, -0.9643, -0.7470, -0.8060, -0.7212, -0.7085, -0.6666, -0.1290,
         -0.8104],
        [-0.8538, -1.0689, -0.8623, -0.8377, -0.7735, -0.7122, -0.7087, -0.1424,
         -0.8402],
        [-0.6972, -0.8865, -0.6912, -0.8754, -0.7241, -0.6991, -0.6251, -0.2317,
         -0.8263],
        [-0.8061, -1.1232, -0.8933, -0.9401, -0.8211, -0.7386, -0.7338, -0.1898,
         -0.8901],
        [-0.7974, -1.0758, -0.8739, -0.9657, -0.7880, -0.7573, -0.7826, -0.1467,
         -0.8776],
        [-0.7357, -0.9573, -0.7640, -0.8276, -0.7571, -0.6927, -0.6697, -0.1706,
         -0.7955],
        [-0.7934, -1.1058, -0.8554, -0.9689, -0.7777, -0.7759, -0.8210, -0.1300,
  

KeyboardInterrupt: 