
# Dependencies

In [24]:
import custom_models
#python packages
from PIL import Image
from tqdm.notebook import tqdm
#from tqdm import tqdm
import gc
import datetime
import os
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from skimage import io
#torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
#torchvision
import torchvision
from torchvision import datasets, models, transforms
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print("Using GPU!")
else:
    print("WARNING: Could not find GPU! Using CPU only")

PyTorch Version:  1.2.0
Torchvision Version:  0.4.0a0+6b959ee
Using GPU!


# Initialize CNN model

In [25]:
def make_CNN(model_name, num_classes, resume_from = None):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    # The model (nn.Module) to return
    model_ft = None
    # The input image is expected to be (input_size, input_size)
    input_size = 0
    
    # You may NOT use pretrained models!! 
    use_pretrained = False
    
    # By default, all parameters will be trained (useful when you're starting from scratch)
    # Within this function you can set .requires_grad = False for various parameters, if you
    # don't want to learn them

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
        
    elif model_name == "vgg13":
        """ VGG13_bn
        """
        model_ft = custom_models.vgg13_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
        
    elif model_name == "vgg16":
        """ VGG16_bn
        """
        model_ft = custom_models.vgg13_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
        
    elif model_name == "vgg19":
        """ VGG19_bn
        """
        model_ft = custom_models.vgg13_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes) 
        input_size = 224

    else:
        raise Exception("Invalid model name!")
    
    return model_ft, input_size

# Define Multimodal Model

In [26]:
class Chest_Disease_Net(nn.Module):
    """
    fc1: number of neurons in the hidden fully connected layer
    """
    def __init__(self, cnn_model_name, num_classes, num_multimodal_features=12, fc1_out=32, resume_from=None):
        #num_classes = 14
        #num_multimodal_features= 12
        super(Chest_Disease_Net, self).__init__()
        self.cnn, self.input_size = make_CNN(cnn_model_name, num_classes)#models.vgg11(pretrained=False, progress = True)
        #define output layers
        self.fc1 = nn.Linear(num_classes + num_multimodal_features, fc1_out) #takes in input of CNN and multimodal input
        self.fc2 = nn.Linear(fc1_out, num_classes)
        if resume_from is not None:
            print("Loading weights from %s" % resume_from)
            self.load_state_dict(torch.load(resume_from))
        
    def forward(self, image, data):
        x1 = self.cnn(image)
        #print("x1", x1.shape)
        x2 = data
        #print("x2", x2.shape)
        #print("x1: ", x1, type(x1))
        #print("x2: ", x2, type(x2))
        #x = torch.cat((x1, x2), dim=1)  
        x = torch.cat((x1.float(), x2.float()), dim=1) ### ???
        #print("concat", x.shape)
        x = F.relu(self.fc1(x))
        #print("relu", x.shape)
        x = self.fc2(x)
        print('forward output: ', x)
       # print("fc2", x.shape)
        return x.double() ### ???

# Data Loading

In [27]:
class MultimodalDataset(Dataset):
    """
    Custom dataset definition
    """
    def __init__(self, csv_path, img_path, transform=None):
        """
        """
        self.df = pd.read_csv(csv_path)
        self.img_path = img_path
        self.transform = transform
        self.diseases = self.get_diseases()
        
            
    def __getitem__(self, index):
        """
        """
        img_name = self.df.iloc[index]["img_name"] 
        img_path = os.path.join(self.img_path, img_name)
        image = Image.open(img_path)
        image = image.convert("RGB")
        image = np.asarray(image)
        if self.transform is not None:
            image = self.transform(image)
        dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor # ???
        features = np.fromstring(self.df.iloc[index]["feature"][1:-1], sep=",") # ???
        features = torch.from_numpy(features.astype("float")) # ???
        #label = int(self.df.iloc[index]['label'])
        labels = torch.tensor(list(self.df.iloc[index][self.diseases]), dtype = torch.float64)
        #print("Label type: ", type(label))
        #label = np.int_(label) #???
        #print("label type post casting: ", type(label))
        return image, features, labels
        
    def __len__(self):
        return len(self.df)
    
    def get_diseases(self):
        cols = list(self.df.columns)
        cols.remove('disease')
        cols.remove('feature')
        cols.remove('img_name')
        cols.remove('dataset_type')
        return cols
        

In [28]:
def get_dataloaders(input_size, batch_size, num_classes, augment=False, shuffle = True):
    # How to transform the image when you are loading them.
    # you'll likely want to mess with the transforms on the training set.
    
    # For now, we resize/crop the image to the correct input size for our network,
    # then convert it to a [C,H,W] tensor, then normalize it to values with a given mean/stdev. These normalization constants
    # are derived from aggregating lots of data and happen to produce better results.
    data_transforms = {
        'train': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            #Add extra transformations for data augmentation
            transforms.RandomApply([
                transforms.RandomChoice([
                    transforms.RandomAffine(degrees=20),
                    transforms.RandomAffine(degrees=0,scale=(0.1, 0.15)),
                    transforms.RandomAffine(degrees=0,translate=(0.2,0.2)),
                    #transforms.RandomAffine(degrees=0,shear=0.15),
                    transforms.RandomHorizontalFlip(p=1.0)
                ] if augment else [transforms.RandomAffine(degrees=0)])#else do nothing
            ], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.225])
        ]),
        'test': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.225])
        ])
    }
    # Create training and validation datasets
    data_subsets = {x: MultimodalDataset(csv_path="./data/"+x+"_dataset{}.csv".format(num_classes), 
                                         img_path="/storage/images", 
                                         transform=data_transforms[x]) for x in data_transforms.keys()}
    # Create training and validation dataloaders
    # Never shuffle the test set
    dataloaders_dict = {x: DataLoader(data_subsets[x], batch_size=batch_size, shuffle=False if x != 'train' else shuffle, num_workers=4) for x in data_transforms.keys()}
    return dataloaders_dict

# Training

In [36]:
def train_model(model, multimodal, dataloaders, criterion, optimizer, scheduler, model_name=str(datetime.datetime.now()), 
                save_dir = None, save_all_epochs=False, num_epochs=25):
    '''
    model: The NN to train
    dataloaders: A dictionary containing at least the keys 
                 'train','val' that maps to Pytorch data loaders for the dataset
    criterion: The Loss function
    optimizer: The algorithm to update weights 
               (Variations on gradient descent)
    num_epochs: How many epochs to train for
    save_dir: Where to save the best model weights that are found, 
              as they are found. Will save to save_dir/weights_best.pt
              Using None will not write anything to disk
    save_all_epochs: Whether to save weights for ALL epochs, not just the best
                     validation error epoch. Will save to save_dir/weights_e{#}.pt
    '''
    since = time.time()

    val_acc_history = []
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            # TQDM has nice progress bars
            for inputs, features, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                features = features.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    if multimodal:
                        outputs = model(inputs, features)
                    else:
                        outputs = model(inputs).type(torch.float64)
                    print("model outputs: ", outputs, outputs.size())
                    #print("model labels: ", labels.size())
                    loss = criterion(outputs, labels)

                    # torch.max outputs the maximum value, and its index
                    # Since the input is batched, we take the max along axis 1
                    # (the meaningful outputs)
                    #print("outputs: ", outputs)
                    #preds = torch.max(outputs, 1)
                    preds = (outputs > 0).type(torch.float64)
                    #_, preds = torch.max(outputs, 1)
                    #print("new preds: ", preds)
                    # backprop + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                #print("loss: ", loss.item())
                #print("inputs: ", inputs.size(0), inputs.size())
                
                running_loss += loss.item() * inputs.size(0)
                #print("running loss: ", running_loss)
                print("predictions: ", preds, preds.size())
                print("Labels: ", labels.data, labels.data.size())
                running_corrects += torch.sum(preds == labels.data)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            #is the accuracy calculated correctly?
            print("running_corrects: ", running_corrects.double(), running_corrects.size())
            print("dataloaders len: ", len(dataloaders[phase].dataset)) 
            #maybe dataloaders length * number of classes? Model must predict for all classes
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), save_dir + "/{}_best_weights_1.pt".format(model_name))
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        print()
        scheduler.step()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

# Optimizer & Loss

In [30]:
def make_optimizer(model):
    # Get all the parameters
    params_to_update = model.parameters()
    print("Params to learn:")
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

    # Use SGD
    optimizer = optim.SGD(params_to_update, lr=0.01, momentum=0.9)
    return optimizer

def get_loss(num_classes,device):
    # Create an instance of the loss function
    #set weights to account for unbalanced data.
    #In expectation every category class contributes the same to the loss
    pos_weight = np.array([    0.025944895136267, 0.059170436277992,0.049543908183484,
                            0.124453250588807,0.013364985606939,0.185904145949381,
                            0.108022729821676,0.564067815619275,0.054011364910838])
    pos_weight = torch.tensor(pos_weight,dtype=torch.float64) if num_classes == 9 \
    else torch.tensor([])
    pos_weight = pos_weight.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    return criterion

# Define Experiment Parameters

In [34]:
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet]
# You can add your own, or modify these however you wish!
model_name = "alexnet"

# Number of classes in the dataset
# Miniplaces has 100
num_classes = 9# set to 9 or 14

# Batch size for training (change depending on how much memory you have)
# You should use a power of 2.
batch_size = 64

# Shuffle the input data?
shuffle_datasets = True

# Number of epochs to train for 
num_epochs = 10

### IO
# Path to a model file to use to start weights at
#resume_from = "/home/ubuntu/6.867-xray-project/weights/data_aug_vgg.pt"
resume_from = None

# Directory to save weights to
save_dir = "weights"
os.makedirs(save_dir, exist_ok=True)

# If True saves the weights for all epochs, else only saves the weight of best one
save_all_epochs = False

#Enable multimodal
multimodal = False

# Train!

In [32]:
gc.collect()

990

In [37]:
# Initialize the model for this run
if multimodal:
    model = Chest_Disease_Net(cnn_model_name = model_name, num_classes = num_classes, resume_from = resume_from)
    input_size = model.input_size
else:
    model, input_size = make_CNN(model_name=model_name,num_classes=num_classes, resume_from = resume_from)
    
dataloaders = get_dataloaders(input_size, batch_size, num_classes, shuffle_datasets)
criterion = get_loss(num_classes=num_classes,device=device)

# Move the model to the gpu if needed
model = model.to(device)

optimizer = make_optimizer(model)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10],gamma=0.1)
# Train the model!
trained_model, validation_history = train_model(model=model, multimodal=multimodal,
                                                dataloaders=dataloaders, criterion=criterion, optimizer=optimizer,
            scheduler=scheduler, model_name=model_name, save_dir=save_dir, save_all_epochs=save_all_epochs, num_epochs=num_epochs)

Params to learn:
	 features.0.weight
	 features.0.bias
	 features.3.weight
	 features.3.bias
	 features.6.weight
	 features.6.bias
	 features.8.weight
	 features.8.bias
	 features.10.weight
	 features.10.bias
	 classifier.1.weight
	 classifier.1.bias
	 classifier.4.weight
	 classifier.4.bias
	 classifier.6.weight
	 classifier.6.bias
Epoch 0/9
----------


HBox(children=(IntProgress(value=0, max=1005), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38010fe828>>
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 926, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 906, in _shutdown_workers
    w.join()
  File "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38010fe828>>
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.d

model outputs:  tensor([[-2.0565e-02,  2.6612e-03, -3.3670e-03, -1.2585e-02,  2.5765e-02,
         -6.6867e-03, -1.0234e-02,  2.2622e-02, -2.6087e-03],
        [-1.3005e-02,  4.3028e-03, -3.9609e-03,  2.9693e-03,  1.0114e-02,
         -6.2731e-03, -3.4548e-02,  1.7021e-02, -1.1387e-04],
        [-1.7200e-02,  1.2711e-03, -3.2782e-03, -1.2602e-02,  2.1503e-02,
         -4.7971e-03, -1.1838e-02,  1.1922e-02,  7.9263e-04],
        [-1.9981e-02, -2.2260e-04,  5.3618e-03, -8.8174e-03,  1.2821e-02,
         -9.2966e-03, -1.3570e-02,  1.8350e-02, -4.6068e-03],
        [-1.2018e-02,  1.7674e-02, -1.5926e-03, -3.6958e-03,  2.0728e-02,
         -2.3673e-03, -3.0524e-03,  1.3564e-02,  2.0132e-03],
        [-6.4156e-03,  2.4536e-03,  1.1266e-02, -1.4409e-02,  1.7443e-02,
         -1.2461e-02, -1.0097e-02,  1.0281e-02,  4.1139e-03],
        [-1.3825e-02,  4.9947e-03, -2.5706e-03, -9.6975e-03,  1.7101e-02,
         -1.1029e-02, -8.3604e-03,  1.5776e-02,  4.8307e-03],
        [-2.0158e-02,  1.7359e-0

model outputs:  tensor([[-2.1693e-02,  1.1289e-03, -2.6041e-04, -2.2622e-02,  1.0383e-02,
         -6.9136e-03, -8.3866e-03,  1.4565e-02, -5.1682e-03],
        [-1.7584e-02,  2.7874e-03,  4.0967e-03, -1.2023e-02,  1.2799e-02,
          1.4302e-03, -1.2700e-02,  1.7316e-02, -2.1172e-03],
        [-1.0435e-02, -1.5543e-03,  6.5548e-04, -1.0702e-02,  1.2984e-02,
         -1.5426e-02, -1.0116e-02,  2.0192e-02, -5.7376e-03],
        [-1.8736e-02, -2.1874e-03, -3.7227e-05, -1.9706e-02,  1.7119e-02,
         -6.5886e-03, -9.1857e-03,  1.3764e-02, -7.6683e-03],
        [-2.3306e-02, -3.0764e-04, -7.6023e-03, -1.0897e-02,  1.1672e-02,
         -8.7115e-03, -1.6854e-02,  1.0878e-02, -1.7788e-03],
        [-2.2191e-02,  9.1034e-03,  1.6501e-06, -1.9022e-02,  8.1750e-03,
         -6.4274e-03, -1.7283e-02,  1.8348e-02,  5.2670e-03],
        [-1.8590e-02, -6.8796e-04, -1.0882e-03, -1.4048e-02,  1.4789e-02,
         -4.4947e-03, -8.9878e-03,  2.2418e-02,  1.9304e-03],
        [-1.9534e-02,  1.2054e-0

model outputs:  tensor([[-0.0210, -0.0111,  0.0014, -0.0149,  0.0055, -0.0141, -0.0198,  0.0185,
         -0.0022],
        [-0.0257, -0.0077, -0.0093, -0.0154,  0.0048, -0.0184, -0.0188,  0.0149,
         -0.0025],
        [-0.0211, -0.0027, -0.0122, -0.0157,  0.0123, -0.0166, -0.0242,  0.0208,
         -0.0059],
        [-0.0213, -0.0072, -0.0095, -0.0207,  0.0085, -0.0126, -0.0187,  0.0184,
         -0.0060],
        [-0.0229, -0.0012, -0.0107, -0.0246,  0.0112, -0.0116, -0.0181,  0.0154,
         -0.0060],
        [-0.0206, -0.0013, -0.0039, -0.0080,  0.0038, -0.0144, -0.0186,  0.0137,
         -0.0072],
        [-0.0223, -0.0015, -0.0064, -0.0190,  0.0040, -0.0168, -0.0151,  0.0084,
         -0.0100],
        [-0.0238, -0.0046,  0.0009, -0.0227,  0.0125, -0.0079, -0.0119,  0.0104,
         -0.0046],
        [-0.0307, -0.0151, -0.0176, -0.0217,  0.0172, -0.0152, -0.0153,  0.0193,
         -0.0183],
        [-0.0219, -0.0116, -0.0024, -0.0253,  0.0141, -0.0155, -0.0220,  0.0101,
   

model outputs:  tensor([[-3.6148e-02, -1.4160e-02, -1.0174e-02, -3.0629e-02, -9.1057e-03,
         -2.8325e-02, -3.2541e-02,  1.0835e-02, -2.9921e-02],
        [-3.6362e-02, -2.0054e-02, -1.0302e-02, -2.5233e-02, -8.6301e-03,
         -2.4671e-02, -2.2850e-02,  1.7065e-02, -2.1689e-02],
        [-3.9039e-02, -1.5335e-02, -1.2719e-02, -3.7395e-02, -1.9499e-03,
         -2.3811e-02, -2.0155e-02,  1.1734e-02, -2.1848e-02],
        [-2.7664e-02, -8.9511e-03, -1.1633e-02, -2.8115e-02, -6.9095e-03,
         -1.9804e-02, -2.0924e-02,  9.2933e-03, -9.6610e-03],
        [-3.3891e-02, -1.6187e-02, -1.2098e-02, -2.5215e-02, -1.1769e-02,
         -1.9269e-02, -2.6779e-02,  6.5076e-03, -2.1151e-02],
        [-3.4659e-02, -6.1973e-03, -1.2354e-02, -3.1250e-02, -3.5433e-03,
         -1.9685e-02, -2.4521e-02,  1.0481e-02, -1.3187e-02],
        [-3.5442e-02, -9.6865e-03, -1.5940e-02, -3.0761e-02, -4.0807e-03,
         -2.6922e-02, -2.0681e-02,  1.0992e-02, -1.9003e-02],
        [-2.4984e-02, -7.3213e-0

model outputs:  tensor([[-0.0323, -0.0223, -0.0244, -0.0364, -0.0067, -0.0230, -0.0300,  0.0094,
         -0.0250],
        [-0.0377, -0.0278, -0.0236, -0.0344, -0.0107, -0.0261, -0.0288,  0.0156,
         -0.0344],
        [-0.0342, -0.0216, -0.0280, -0.0392, -0.0155, -0.0259, -0.0370,  0.0086,
         -0.0210],
        [-0.0413, -0.0245, -0.0189, -0.0250, -0.0116, -0.0247, -0.0350,  0.0108,
         -0.0255],
        [-0.0388, -0.0242, -0.0207, -0.0353, -0.0113, -0.0282, -0.0346,  0.0094,
         -0.0266],
        [-0.0365, -0.0193, -0.0244, -0.0316, -0.0142, -0.0375, -0.0293,  0.0064,
         -0.0303],
        [-0.0439, -0.0201, -0.0223, -0.0281, -0.0138, -0.0214, -0.0184,  0.0093,
         -0.0299],
        [-0.0381, -0.0226, -0.0311, -0.0322, -0.0220, -0.0292, -0.0346,  0.0117,
         -0.0229],
        [-0.0305, -0.0270, -0.0192, -0.0295, -0.0115, -0.0295, -0.0319,  0.0112,
         -0.0265],
        [-0.0334, -0.0216, -0.0264, -0.0411, -0.0160, -0.0276, -0.0366,  0.0105,
   

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-0.0634, -0.0521, -0.0459, -0.0572, -0.0400, -0.0447, -0.0535,  0.0111,
         -0.0366],
        [-0.0592, -0.0489, -0.0448, -0.0483, -0.0342, -0.0462, -0.0518,  0.0110,
         -0.0440],
        [-0.0597, -0.0510, -0.0493, -0.0617, -0.0412, -0.0525, -0.0512,  0.0079,
         -0.0466],
        [-0.0627, -0.0459, -0.0570, -0.0624, -0.0416, -0.0473, -0.0618,  0.0146,
         -0.0551],
        [-0.0631, -0.0510, -0.0547, -0.0543, -0.0313, -0.0490, -0.0606,  0.0133,
         -0.0508],
        [-0.0580, -0.0470, -0.0551, -0.0604, -0.0379, -0.0501, -0.0472,  0.0122,
         -0.0523],
        [-0.0715, -0.0461, -0.0397, -0.0558, -0.0383, -0.0477, -0.0652,  0.0017,
         -0.0536],
        [-0.0681, -0.0470, -0.0534, -0.0671, -0.0355, -0.0450, -0.0512,  0.0142,
         -0.0465],
        [-0.0650, -0.0522, -0.0539, -0.0607, -0.0274, -0.0511, -0.0499,  0.0132,
         -0.0477],
        [-0.0738, -0.0493, -0.0350, -0.0562, -0.0355, -0.0525, -0.0641,  0.0081,
   

model outputs:  tensor([[-0.0834, -0.0590, -0.0570, -0.0723, -0.0456, -0.0591, -0.0682,  0.0071,
         -0.0590],
        [-0.0821, -0.0646, -0.0708, -0.0711, -0.0548, -0.0576, -0.0658,  0.0057,
         -0.0621],
        [-0.0769, -0.0522, -0.0507, -0.0605, -0.0485, -0.0581, -0.0632,  0.0079,
         -0.0574],
        [-0.0848, -0.0623, -0.0543, -0.0682, -0.0503, -0.0630, -0.0742,  0.0128,
         -0.0596],
        [-0.0747, -0.0649, -0.0552, -0.0767, -0.0525, -0.0659, -0.0744,  0.0089,
         -0.0708],
        [-0.0695, -0.0572, -0.0595, -0.0698, -0.0485, -0.0585, -0.0608,  0.0107,
         -0.0625],
        [-0.0897, -0.0627, -0.0555, -0.0747, -0.0598, -0.0718, -0.0728, -0.0020,
         -0.0769],
        [-0.0783, -0.0640, -0.0603, -0.0597, -0.0478, -0.0576, -0.0659,  0.0006,
         -0.0653],
        [-0.0872, -0.0627, -0.0602, -0.0681, -0.0685, -0.0633, -0.0752,  0.0175,
         -0.0760],
        [-0.0793, -0.0584, -0.0558, -0.0685, -0.0496, -0.0648, -0.0676,  0.0042,
   

model outputs:  tensor([[-9.7929e-02, -7.4431e-02, -7.3893e-02, -8.5012e-02, -7.5169e-02,
         -7.0609e-02, -8.3265e-02,  1.4911e-02, -7.3196e-02],
        [-1.0673e-01, -7.5092e-02, -8.0520e-02, -9.1211e-02, -7.0359e-02,
         -8.2130e-02, -9.2941e-02,  7.6744e-03, -8.7719e-02],
        [-9.9920e-02, -7.5661e-02, -6.7625e-02, -8.1104e-02, -6.8777e-02,
         -7.3511e-02, -8.8372e-02,  1.0520e-02, -7.9790e-02],
        [-1.0381e-01, -8.0453e-02, -7.6967e-02, -9.2915e-02, -7.7968e-02,
         -7.8031e-02, -8.2161e-02, -4.5588e-03, -8.4118e-02],
        [-1.1714e-01, -9.0292e-02, -7.8138e-02, -9.7636e-02, -8.3872e-02,
         -7.1598e-02, -8.4552e-02,  1.5220e-02, -8.6083e-02],
        [-1.0318e-01, -8.9621e-02, -8.1460e-02, -8.9898e-02, -7.2321e-02,
         -8.0818e-02, -7.7703e-02,  1.7464e-03, -9.4836e-02],
        [-9.9529e-02, -7.0669e-02, -7.1367e-02, -8.2678e-02, -6.6819e-02,
         -7.0160e-02, -7.8624e-02,  1.9163e-03, -7.0489e-02],
        [-9.3319e-02, -7.7833e-0

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-0.1336, -0.1114, -0.1182, -0.1120, -0.1145, -0.0964, -0.1274,  0.0101,
         -0.1220],
        [-0.1364, -0.1158, -0.1116, -0.1254, -0.1100, -0.1224, -0.1210, -0.0061,
         -0.1261],
        [-0.1226, -0.1105, -0.1075, -0.1194, -0.0996, -0.1008, -0.1044, -0.0007,
         -0.1194],
        [-0.1391, -0.1311, -0.1290, -0.1318, -0.1223, -0.1058, -0.1285,  0.0023,
         -0.1072],
        [-0.1237, -0.1010, -0.1040, -0.1069, -0.1010, -0.1005, -0.1083,  0.0010,
         -0.0985],
        [-0.1236, -0.1037, -0.1021, -0.1104, -0.1020, -0.1021, -0.1128, -0.0044,
         -0.1140],
        [-0.1269, -0.1029, -0.1179, -0.1276, -0.1067, -0.0978, -0.1174, -0.0030,
         -0.1140],
        [-0.1618, -0.1395, -0.1263, -0.1461, -0.1358, -0.1118, -0.1363, -0.0103,
         -0.1679],
        [-0.1399, -0.1390, -0.1501, -0.1497, -0.1360, -0.1459, -0.1378, -0.0125,
         -0.1395],
        [-0.1327, -0.1028, -0.1128, -0.1106, -0.0883, -0.1025, -0.1023, -0.0029,
   

model outputs:  tensor([[-1.6242e-01, -1.4755e-01, -1.4733e-01, -1.5050e-01, -1.3333e-01,
         -1.3164e-01, -1.5218e-01, -3.5271e-03, -1.4087e-01],
        [-1.6904e-01, -1.4286e-01, -1.5058e-01, -1.4767e-01, -1.3436e-01,
         -1.4239e-01, -1.6284e-01, -4.8614e-04, -1.4731e-01],
        [-1.4852e-01, -1.2578e-01, -1.2390e-01, -1.2488e-01, -1.1730e-01,
         -1.1257e-01, -1.3013e-01, -7.7331e-03, -1.3405e-01],
        [-1.8710e-01, -1.6839e-01, -1.6625e-01, -1.5920e-01, -1.5925e-01,
         -1.3985e-01, -1.6453e-01,  1.8499e-03, -1.5588e-01],
        [-1.6652e-01, -1.4395e-01, -1.3713e-01, -1.4044e-01, -1.3404e-01,
         -1.3516e-01, -1.2920e-01,  1.4674e-03, -1.3218e-01],
        [-1.8033e-01, -1.4001e-01, -1.5034e-01, -1.5377e-01, -1.4285e-01,
         -1.4658e-01, -1.5993e-01, -1.4560e-02, -1.5308e-01],
        [-1.5919e-01, -1.3793e-01, -1.4118e-01, -1.3966e-01, -1.4383e-01,
         -1.2130e-01, -1.4508e-01,  1.9952e-03, -1.5794e-01],
        [-1.5469e-01, -1.5599e-0

model outputs:  tensor([[-1.9258e-01, -1.8029e-01, -1.8290e-01, -1.7324e-01, -1.8372e-01,
         -1.6217e-01, -1.9000e-01, -2.0049e-02, -1.7633e-01],
        [-1.6785e-01, -1.4549e-01, -1.5132e-01, -1.5732e-01, -1.4780e-01,
         -1.3686e-01, -1.5792e-01, -4.6313e-03, -1.4700e-01],
        [-1.8648e-01, -1.7463e-01, -1.8208e-01, -1.7845e-01, -1.7890e-01,
         -1.6065e-01, -1.7132e-01, -1.0553e-02, -1.8703e-01],
        [-1.7629e-01, -1.4719e-01, -1.5794e-01, -1.5934e-01, -1.5148e-01,
         -1.4231e-01, -1.5557e-01, -6.2260e-03, -1.6769e-01],
        [-2.3226e-01, -1.8871e-01, -1.8464e-01, -1.8883e-01, -1.7549e-01,
         -1.8294e-01, -1.8087e-01, -3.6178e-03, -1.7512e-01],
        [-1.9203e-01, -1.7123e-01, -1.6782e-01, -1.5973e-01, -1.6928e-01,
         -1.5777e-01, -1.5980e-01, -1.5403e-02, -1.6535e-01],
        [-1.9566e-01, -1.9930e-01, -1.7202e-01, -1.8464e-01, -1.9407e-01,
         -1.7489e-01, -1.6523e-01, -1.1966e-02, -1.7720e-01],
        [-1.9079e-01, -1.6198e-0

model outputs:  tensor([[-0.2717, -0.2913, -0.2955, -0.2436, -0.2935, -0.2473, -0.2907, -0.0259,
         -0.2813],
        [-0.3197, -0.2506, -0.2630, -0.2742, -0.2725, -0.2364, -0.2377, -0.0020,
         -0.2988],
        [-0.3334, -0.2925, -0.2962, -0.2919, -0.3124, -0.2560, -0.3154, -0.0485,
         -0.3250],
        [-0.4485, -0.3933, -0.4402, -0.3926, -0.4832, -0.3677, -0.3524, -0.0221,
         -0.4777],
        [-0.3033, -0.2740, -0.2913, -0.2779, -0.3050, -0.2631, -0.2418,  0.0030,
         -0.3096],
        [-0.3283, -0.3211, -0.3145, -0.2992, -0.3571, -0.2945, -0.2868, -0.0463,
         -0.3274],
        [-0.2796, -0.2609, -0.2579, -0.2958, -0.2481, -0.2334, -0.2198, -0.0212,
         -0.2670],
        [-0.2485, -0.2264, -0.2373, -0.2272, -0.2450, -0.2175, -0.2263, -0.0104,
         -0.2087],
        [-0.4684, -0.4685, -0.4074, -0.4417, -0.5245, -0.3466, -0.4238, -0.0316,
         -0.4645],
        [-0.2886, -0.2502, -0.2221, -0.2443, -0.2539, -0.2306, -0.2335, -0.0275,
   

model outputs:  tensor([[-2.7185e-01, -2.5550e-01, -2.3568e-01, -2.4170e-01, -2.6546e-01,
         -2.4427e-01, -2.6937e-01, -3.9854e-02, -2.7160e-01],
        [-3.3633e-01, -3.2718e-01, -2.9972e-01, -3.0693e-01, -3.1903e-01,
         -2.4801e-01, -2.9934e-01, -2.7412e-02, -3.2944e-01],
        [-3.8100e-01, -3.5484e-01, -3.3187e-01, -3.7216e-01, -3.8285e-01,
         -3.4068e-01, -3.5110e-01, -2.8255e-02, -3.6988e-01],
        [-3.3616e-01, -2.9246e-01, -2.8535e-01, -2.9666e-01, -3.3016e-01,
         -2.5445e-01, -3.1203e-01, -2.8380e-02, -2.9566e-01],
        [-2.8403e-01, -2.4585e-01, -2.4538e-01, -2.6575e-01, -2.5860e-01,
         -2.2316e-01, -2.2479e-01, -1.4005e-02, -2.5857e-01],
        [-3.2206e-01, -3.1189e-01, -3.2937e-01, -3.0760e-01, -3.3935e-01,
         -2.7575e-01, -2.9460e-01, -1.2757e-02, -3.1024e-01],
        [-3.5565e-01, -3.4931e-01, -3.2550e-01, -3.6110e-01, -3.8035e-01,
         -2.7138e-01, -3.1916e-01, -3.7951e-02, -3.7599e-01],
        [-4.3864e-01, -4.0812e-0

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-2.0901, -2.3279, -2.2153, -2.1605, -2.4814, -1.5315, -2.0673, -0.2523,
         -2.3621],
        [-1.2027, -1.0712, -1.0134, -1.0911, -1.2084, -0.9579, -1.0373, -0.0100,
         -1.1519],
        [-0.9064, -0.9094, -0.7518, -0.8627, -0.9185, -0.7106, -0.7698, -0.1028,
         -0.9241],
        [-2.1662, -2.2444, -2.3382, -2.1516, -2.2770, -1.9623, -2.1256, -0.0953,
         -2.2961],
        [-1.1984, -1.1671, -1.0326, -1.0875, -1.2544, -0.9592, -0.9972, -0.1713,
         -1.0697],
        [-0.7765, -0.7506, -0.7190, -0.6900, -0.7798, -0.6502, -0.6965, -0.1116,
         -0.8584],
        [-1.3646, -1.3650, -1.2504, -1.2213, -1.4283, -1.1775, -1.3147, -0.1726,
         -1.4473],
        [-1.1937, -0.9901, -1.0292, -1.1238, -1.1508, -0.9121, -0.9786, -0.1246,
         -1.1537],
        [-0.9509, -0.9220, -0.8503, -0.8417, -0.9680, -0.7729, -0.7869, -0.0385,
         -0.9103],
        [-1.4395, -1.3552, -1.3888, -1.4621, -1.4430, -1.0949, -1.3400, -0.0574,
   

Labels:  tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 

model outputs:  tensor([[ -5.1088,  -5.0478,  -4.9435,  -4.8466,  -5.1677,  -3.7124,  -4.5782,
          -0.5327,  -4.8648],
        [ -8.8962,  -8.1891,  -8.6223,  -8.0284,  -9.1620,  -7.0906,  -7.3549,
          -0.7968,  -8.4773],
        [ -4.2790,  -4.1210,  -3.8331,  -3.8264,  -4.2838,  -3.4369,  -3.7892,
          -0.7296,  -4.1792],
        [ -8.4507,  -8.8726,  -8.6259,  -7.6444,  -8.4257,  -6.4105,  -7.6623,
          -0.9135,  -8.8930],
        [ -3.5390,  -3.5735,  -3.5072,  -3.4925,  -3.7217,  -2.7663,  -3.0476,
          -0.3672,  -3.7196],
        [ -2.8704,  -2.7528,  -2.6847,  -2.7819,  -2.9460,  -2.2369,  -2.5620,
          -0.3100,  -2.8638],
        [-15.4040, -14.7304, -15.3311, -13.9307, -15.5956, -11.7809, -13.4785,
          -1.1570, -13.5351],
        [ -6.3782,  -6.4275,  -6.4879,  -5.9037,  -6.9580,  -5.3531,  -5.8675,
          -0.7173,  -6.5761],
        [ -8.5176,  -8.2497,  -7.6682,  -7.2467,  -8.5712,  -6.4842,  -7.3141,
          -1.0011,  -8.6327],
   

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-42.9854, -39.4234, -39.0594, -36.4426, -44.0205, -27.4176, -35.6212,
          -2.4690, -42.1854],
        [-13.2367, -12.6771, -13.0748, -11.5146, -12.6916,  -9.4194, -12.3829,
          -0.6113, -12.8843],
        [-22.5432, -21.6972, -21.2310, -19.5916, -23.6170, -14.9235, -18.4339,
          -1.0278, -22.1197],
        [-16.8937, -16.0231, -15.0757, -14.9101, -17.8713, -10.8122, -13.7217,
          -0.4280, -16.6912],
        [ -6.1378,  -5.8583,  -5.6683,  -5.5053,  -6.3611,  -3.9400,  -5.3783,
          -0.3050,  -6.0205],
        [-20.0596, -21.2864, -20.5418, -18.4347, -20.8625, -14.1671, -17.3671,
          -1.8105, -20.4867],
        [-18.3405, -18.8487, -18.5600, -17.0465, -20.0225, -11.8956, -16.2014,
          -1.2722, -19.9135],
        [-25.7728, -24.3941, -24.9432, -22.5240, -27.7636, -17.4341, -22.8992,
          -2.9904, -25.8906],
        [-12.8658, -12.4042, -12.1840, -10.6328, -12.5208,  -7.9320, -10.4343,
          -0.1983, -12.1293],
   

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-25.9348, -26.4325, -25.1977, -20.1466, -28.5340, -10.0414, -21.0034,
          -1.3408, -24.6666],
        [-12.9563, -11.7442, -12.4262, -10.0429, -13.4499,  -4.0828, -10.0656,
          -0.6701, -12.5821],
        [-15.6714, -14.7869, -14.7090, -12.4243, -15.6656,  -5.6775, -11.4877,
          -1.2551, -14.1953],
        [-23.3869, -23.6273, -21.7820, -18.0999, -23.4164,  -7.0232, -17.3301,
          -1.5847, -21.9933],
        [ -9.5287,  -9.0452,  -9.1183,  -6.9906, -10.0508,  -3.3475,  -7.3576,
          -0.1189,  -9.6306],
        [-14.0845, -14.1818, -13.9126, -11.6784, -14.5318,  -4.4384, -11.7917,
          -0.4277, -14.3945],
        [-18.2368, -17.2768, -16.4307, -13.2879, -19.4230,  -6.2302, -13.9645,
          -1.3536, -18.0148],
        [-25.8933, -24.5475, -24.2701, -19.9039, -25.9320,  -8.7414, -19.7175,
          -1.2514, -24.3290],
        [-28.6412, -27.2429, -27.3466, -22.0685, -29.2946, -10.1258, -22.5349,
          -1.2545, -27.5577],
   

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-10.2694,  -9.3435,  -9.0186,  -5.2564, -10.4592,  -0.4292,  -6.7055,
          -0.7858,  -9.5885],
        [ -5.2439,  -4.8163,  -5.2007,  -3.0283,  -5.7253,  -0.1532,  -3.7869,
          -0.2082,  -5.0609],
        [-12.9198, -11.5406, -12.3633,  -7.6025, -13.2349,  -0.5956,  -8.8825,
          -1.0236, -12.8090],
        [-11.2808, -10.4841, -10.5942,  -6.6465, -12.1260,  -0.6502,  -7.6396,
          -1.9336, -11.2060],
        [-13.6359, -12.7032, -12.4287,  -7.8051, -14.2225,  -0.6675,  -8.8033,
          -1.1332, -12.5904],
        [-10.4202,  -9.3989,  -9.5397,  -6.5740, -11.0769,  -1.1348,  -6.8913,
          -0.9194, -10.1225],
        [ -9.3647,  -8.3107,  -8.1660,  -5.1367,  -9.2774,  -0.6417,  -5.8699,
          -0.7640,  -8.6246],
        [-23.6651, -22.9709, -22.7717, -13.3096, -25.0860,  -1.6052, -16.1889,
          -1.2424, -23.7263],
        [-16.6041, -15.2753, -15.8904, -10.3307, -17.5721,  -1.2352, -12.0414,
          -1.9803, -16.4424],
   

model outputs:  tensor([[-10.6610, -10.0297,  -9.8404,  -6.1581, -11.6944,  -1.8068,  -6.8511,
           0.1748, -10.3763],
        [-19.9346, -18.3083, -20.0638,  -9.6040, -21.1328,  -3.4189, -13.5082,
           0.7918, -19.0296],
        [ -5.4152,  -4.8099,  -4.8393,  -2.7495,  -5.7360,  -1.0705,  -3.0841,
           0.2199,  -4.8226],
        [-11.3543,  -9.9754, -10.6385,  -4.9300, -11.7673,  -2.7713,  -6.9571,
           0.7658, -10.4151],
        [ -7.9387,  -7.1324,  -7.0992,  -3.9507,  -8.6964,  -1.4766,  -4.4599,
           0.0843,  -7.3528],
        [-10.2129,  -8.9989,  -8.9277,  -5.1527, -10.5465,  -2.1984,  -6.2340,
           0.0933,  -9.6147],
        [-10.1491,  -9.2937,  -9.7383,  -4.9340, -10.8027,  -2.0974,  -6.5766,
           0.0847,  -9.5318],
        [ -9.1980,  -7.9203,  -8.1236,  -3.7521,  -9.2447,  -1.4976,  -5.4027,
          -0.1182,  -8.1384],
        [ -7.2950,  -7.0431,  -6.5970,  -3.5217,  -7.7794,  -1.6143,  -4.8303,
           0.2926,  -6.8549],
   

model outputs:  tensor([[-4.1324e+00, -3.5594e+00, -3.4052e+00, -1.7612e+00, -4.3107e+00,
         -1.3604e+00, -2.1998e+00,  1.2004e-02, -3.6288e+00],
        [-8.0806e+00, -7.0660e+00, -7.1026e+00, -2.8084e+00, -8.4828e+00,
         -2.9932e+00, -4.5270e+00,  3.7564e-02, -7.0468e+00],
        [-4.2001e+00, -3.7685e+00, -3.6239e+00, -1.8823e+00, -4.2592e+00,
         -1.5813e+00, -2.3074e+00,  1.5970e-01, -3.7668e+00],
        [-7.1805e+00, -5.9108e+00, -6.6956e+00, -2.9223e+00, -7.6987e+00,
         -2.4774e+00, -3.7506e+00,  2.4386e-01, -6.2046e+00],
        [-1.0227e+01, -8.6638e+00, -9.1491e+00, -4.1301e+00, -1.0961e+01,
         -3.7377e+00, -5.7504e+00,  1.3293e-01, -9.5735e+00],
        [-9.9739e+00, -8.7773e+00, -8.9207e+00, -4.5611e+00, -1.0833e+01,
         -3.4772e+00, -5.8807e+00,  3.3281e-01, -9.4165e+00],
        [-1.0332e+01, -8.4990e+00, -9.0728e+00, -4.3085e+00, -1.0786e+01,
         -3.5267e+00, -5.9520e+00, -2.3157e-01, -9.3406e+00],
        [-3.8297e+00, -3.4556e+0

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

model outputs:  tensor([[ -7.7740,  -6.4227,  -6.6306,  -2.9365,  -8.0104,  -4.6534,  -3.5380,
          -0.8311,  -6.8364],
        [ -4.5416,  -3.9375,  -3.9336,  -1.8483,  -5.2074,  -2.8927,  -2.6524,
          -0.6179,  -4.3438],
        [ -4.6916,  -3.8381,  -4.0449,  -1.8433,  -4.8023,  -2.8419,  -2.3476,
          -0.5519,  -4.0832],
        [ -5.2296,  -4.6203,  -4.6159,  -2.1999,  -5.5962,  -3.3901,  -2.5146,
          -0.7877,  -4.8807],
        [-11.0531,  -9.0076,  -9.4817,  -4.3166, -11.8759,  -6.3798,  -5.4695,
          -1.2384,  -9.5405],
        [ -7.4443,  -6.0553,  -6.2249,  -2.6878,  -7.7536,  -4.3214,  -3.8386,
          -0.6569,  -6.2327],
        [ -3.8022,  -3.3530,  -3.4172,  -1.4257,  -4.0806,  -2.4849,  -1.9557,
          -0.2155,  -3.4757],
        [ -5.0290,  -4.3089,  -4.6866,  -2.1249,  -5.3533,  -3.2988,  -2.7182,
          -0.5055,  -4.4541],
        [ -4.8138,  -4.2404,  -4.4483,  -2.1497,  -5.1847,  -3.2449,  -2.4446,
          -0.8620,  -4.4172],
   

model outputs:  tensor([[ -4.9548,  -4.1609,  -4.3688,  -2.0266,  -5.2605,  -3.5746,  -2.5176,
          -0.5081,  -4.2480],
        [ -4.8177,  -4.1030,  -4.1834,  -1.9749,  -5.1302,  -3.5077,  -2.4832,
          -0.7547,  -4.1400],
        [ -3.7057,  -3.0315,  -3.3530,  -1.5700,  -3.9523,  -2.6937,  -1.9124,
          -0.5614,  -3.3536],
        [ -4.0541,  -3.2808,  -3.5911,  -1.6575,  -4.3170,  -2.9113,  -2.1142,
          -0.5601,  -3.3954],
        [ -4.2597,  -3.5753,  -3.6919,  -1.5942,  -4.5331,  -3.0248,  -2.1666,
          -0.3510,  -3.8402],
        [ -4.3077,  -3.5939,  -3.7472,  -1.9311,  -4.6792,  -2.9183,  -2.2328,
          -0.4221,  -3.7776],
        [ -6.9041,  -6.0918,  -6.3474,  -3.1129,  -7.3667,  -5.4190,  -3.4878,
          -0.8810,  -6.6329],
        [ -3.9680,  -3.3646,  -3.4863,  -1.7410,  -4.2579,  -2.9547,  -2.0320,
          -0.6872,  -3.7020],
        [ -5.3523,  -4.4498,  -4.7369,  -2.3693,  -5.7137,  -3.8418,  -2.7643,
          -0.6718,  -4.5495],
   

model outputs:  tensor([[-4.8483e+00, -4.2087e+00, -4.2476e+00, -2.3393e+00, -5.2990e+00,
         -3.6906e+00, -2.5990e+00, -3.8208e-01, -4.4419e+00],
        [-6.3665e+00, -5.4085e+00, -5.6934e+00, -2.8330e+00, -6.6106e+00,
         -5.0248e+00, -3.3248e+00, -3.4633e-01, -5.6341e+00],
        [-5.4168e+00, -4.3946e+00, -4.6380e+00, -2.4544e+00, -5.3854e+00,
         -4.0283e+00, -2.7458e+00, -4.5867e-01, -4.7773e+00],
        [-6.4864e+00, -5.2286e+00, -5.5045e+00, -2.9099e+00, -6.8459e+00,
         -5.0621e+00, -3.1316e+00, -4.5496e-01, -5.5712e+00],
        [-5.7987e+00, -4.5340e+00, -5.2883e+00, -2.7390e+00, -6.0174e+00,
         -4.4303e+00, -3.1785e+00, -2.7546e-01, -5.1170e+00],
        [-6.7180e+00, -5.5010e+00, -5.7495e+00, -3.1621e+00, -7.1878e+00,
         -5.1515e+00, -3.1496e+00, -3.9311e-01, -5.9174e+00],
        [-4.4872e+00, -3.6988e+00, -3.7469e+00, -2.1387e+00, -4.7083e+00,
         -3.2681e+00, -2.2489e+00, -3.4286e-01, -3.7560e+00],
        [-7.2862e+00, -6.4334e+0

model outputs:  tensor([[-5.7029e+00, -4.8798e+00, -5.0619e+00, -3.0597e+00, -6.1311e+00,
         -4.9618e+00, -3.0390e+00, -1.8523e-01, -5.1171e+00],
        [-8.3546e+00, -6.7252e+00, -7.1680e+00, -4.4332e+00, -8.9772e+00,
         -6.2749e+00, -4.4424e+00, -5.3094e-01, -7.1558e+00],
        [-4.7562e+00, -4.2452e+00, -4.4085e+00, -2.8247e+00, -5.3361e+00,
         -4.1842e+00, -2.7290e+00, -9.8606e-02, -4.4219e+00],
        [-6.5558e+00, -5.1667e+00, -5.3266e+00, -3.3699e+00, -6.8732e+00,
         -5.1234e+00, -3.5388e+00, -2.7019e-01, -5.5220e+00],
        [-5.0182e+00, -4.4051e+00, -4.8088e+00, -3.1212e+00, -5.4034e+00,
         -4.4350e+00, -3.0019e+00, -2.4370e-01, -4.6865e+00],
        [-4.2615e+00, -3.5639e+00, -3.8602e+00, -2.3858e+00, -4.5250e+00,
         -3.3221e+00, -2.4424e+00, -6.0181e-02, -3.6233e+00],
        [-6.4281e+00, -5.1205e+00, -5.6806e+00, -3.6179e+00, -6.7439e+00,
         -5.3338e+00, -3.6649e+00, -1.9243e-01, -5.6183e+00],
        [-7.0119e+00, -5.8366e+0

model outputs:  tensor([[-4.8817e+00, -4.0189e+00, -4.4443e+00, -2.8264e+00, -5.3121e+00,
         -4.2282e+00, -2.9761e+00,  8.4169e-02, -4.4008e+00],
        [-4.1833e+00, -3.4853e+00, -3.5716e+00, -2.3290e+00, -4.4487e+00,
         -3.4211e+00, -2.2077e+00, -3.0646e-03, -3.5803e+00],
        [-5.7717e+00, -4.7206e+00, -5.1763e+00, -3.3661e+00, -6.1644e+00,
         -5.1234e+00, -3.0956e+00, -1.4260e-01, -5.3544e+00],
        [-6.7492e+00, -5.8820e+00, -6.2153e+00, -4.0164e+00, -7.4412e+00,
         -5.5989e+00, -3.6917e+00,  1.6663e-01, -6.2093e+00],
        [-3.7020e+00, -2.9694e+00, -3.2356e+00, -2.2685e+00, -3.8895e+00,
         -3.1651e+00, -2.0860e+00, -1.0073e-01, -3.2313e+00],
        [-7.2608e+00, -6.2518e+00, -6.5330e+00, -4.1356e+00, -8.1126e+00,
         -6.3862e+00, -3.9925e+00, -5.2381e-02, -6.2285e+00],
        [-5.1595e+00, -4.5076e+00, -4.4384e+00, -3.0606e+00, -5.7844e+00,
         -4.2242e+00, -3.0405e+00, -2.8364e-01, -4.6122e+00],
        [-1.1286e+01, -9.5703e+0

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-7.4259e+00, -5.7660e+00, -6.3017e+00, -4.7141e+00, -7.3292e+00,
         -5.7776e+00, -4.1764e+00, -5.4054e-01, -6.5301e+00],
        [-1.1396e+01, -9.7996e+00, -1.0507e+01, -7.6687e+00, -1.2292e+01,
         -9.5086e+00, -7.4649e+00, -5.1602e-01, -1.0838e+01],
        [-5.6829e+00, -4.8282e+00, -5.4601e+00, -3.9520e+00, -6.3138e+00,
         -4.9137e+00, -3.6188e+00, -1.5591e-01, -5.4149e+00],
        [-8.3225e+00, -6.8009e+00, -7.3149e+00, -5.2960e+00, -8.7674e+00,
         -6.8733e+00, -5.1242e+00, -3.7429e-01, -7.0410e+00],
        [-7.2408e+00, -6.1852e+00, -6.8582e+00, -5.1910e+00, -7.7946e+00,
         -5.8803e+00, -4.1831e+00, -5.2525e-01, -6.5244e+00],
        [-7.6077e+00, -6.1888e+00, -6.5530e+00, -4.8173e+00, -7.8337e+00,
         -6.1459e+00, -4.4406e+00, -2.6860e-01, -6.5693e+00],
        [-1.1380e+01, -9.8664e+00, -1.0096e+01, -7.3263e+00, -1.2430e+01,
         -9.4294e+00, -6.9401e+00, -9.1037e-01, -9.9808e+00],
        [-8.3764e+00, -7.1920e+0

Labels:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 

model outputs:  tensor([[ -6.1834,  -5.0673,  -5.4078,  -4.4471,  -6.7763,  -4.7069,  -4.1911,
          -0.5054,  -5.4454],
        [ -7.8579,  -6.1439,  -6.7792,  -5.8645,  -8.4471,  -5.5820,  -4.8416,
          -1.3810,  -6.5381],
        [ -7.2634,  -6.0494,  -6.3156,  -4.9105,  -7.6165,  -5.5375,  -4.9579,
          -0.4615,  -6.2177],
        [-11.0664,  -9.0626,  -9.5085,  -7.5967, -11.5749,  -8.5183,  -7.1356,
          -1.0063,  -9.6308],
        [-10.3619,  -9.1441,  -9.4961,  -7.5299, -11.2677,  -8.0698,  -6.6213,
          -1.2944,  -9.6722],
        [ -5.3739,  -4.2402,  -4.5269,  -3.7137,  -5.6167,  -4.0518,  -3.4046,
          -0.3930,  -4.4678],
        [ -5.0746,  -4.2123,  -4.2787,  -3.6290,  -5.1501,  -3.6820,  -3.3875,
          -0.3828,  -4.6154],
        [-10.9197,  -8.8848,  -9.6832,  -7.8852, -11.2423,  -8.3841,  -6.8931,
          -1.0732,  -9.3569],
        [ -6.3453,  -5.2715,  -5.7139,  -4.7502,  -6.5829,  -4.6898,  -4.2019,
          -0.5597,  -5.7835],
   

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-6.8094e+00, -5.3640e+00, -5.9539e+00, -4.6810e+00, -6.9669e+00,
         -4.2064e+00, -4.1809e+00,  4.9195e-04, -5.6406e+00],
        [-5.0603e+00, -4.1634e+00, -4.2845e+00, -3.5105e+00, -5.5283e+00,
         -3.2835e+00, -3.6227e+00, -1.6088e-01, -4.3114e+00],
        [-5.7007e+00, -4.6633e+00, -5.1504e+00, -4.2748e+00, -6.0080e+00,
         -3.7685e+00, -3.7579e+00,  2.8105e-02, -4.8344e+00],
        [-9.4034e+00, -7.7377e+00, -8.8775e+00, -6.8494e+00, -9.9061e+00,
         -6.4244e+00, -6.3905e+00, -5.3580e-01, -8.0373e+00],
        [-3.4048e+00, -2.8069e+00, -2.9990e+00, -2.4472e+00, -3.5978e+00,
         -2.2282e+00, -2.3323e+00, -9.9769e-02, -2.9747e+00],
        [-6.0133e+00, -5.0363e+00, -5.5578e+00, -4.5947e+00, -6.6816e+00,
         -4.2333e+00, -4.0159e+00, -7.1786e-02, -5.2244e+00],
        [-9.2415e+00, -7.8785e+00, -8.3063e+00, -6.7298e+00, -9.8677e+00,
         -6.0297e+00, -6.4230e+00, -2.7078e-01, -7.9922e+00],
        [-6.5231e+00, -5.7128e+0

model outputs:  tensor([[-3.6942e+00, -3.0816e+00, -3.2861e+00, -2.7425e+00, -3.9179e+00,
         -2.2542e+00, -2.5358e+00,  2.2885e-01, -3.3723e+00],
        [-7.2655e+00, -6.1083e+00, -6.0368e+00, -5.1470e+00, -7.6044e+00,
         -4.3888e+00, -5.0121e+00,  5.6242e-01, -6.0921e+00],
        [-9.1921e+00, -7.2800e+00, -8.2928e+00, -6.6200e+00, -9.6053e+00,
         -5.1046e+00, -6.1815e+00,  6.5912e-01, -7.8691e+00],
        [-6.1713e+00, -5.2954e+00, -5.5884e+00, -4.7545e+00, -6.4724e+00,
         -3.8167e+00, -4.3550e+00,  2.8996e-01, -5.6038e+00],
        [-5.7219e+00, -4.6680e+00, -5.0934e+00, -4.2329e+00, -6.2957e+00,
         -3.7213e+00, -4.1783e+00,  1.4572e-01, -4.9259e+00],
        [-6.2025e+00, -5.1983e+00, -5.4864e+00, -4.6997e+00, -6.6814e+00,
         -3.9116e+00, -4.0088e+00,  3.3301e-01, -5.2587e+00],
        [-6.7598e+00, -5.4895e+00, -6.0685e+00, -4.8713e+00, -7.0594e+00,
         -4.1076e+00, -4.5545e+00,  5.3000e-01, -5.3309e+00],
        [-3.8252e+00, -3.2861e+0

model outputs:  tensor([[-7.8513, -6.3554, -6.8703, -5.4137, -8.2436, -4.4538, -5.2985,  0.2556,
         -6.6446],
        [-6.3266, -5.1187, -5.4872, -4.4386, -6.8943, -3.5220, -4.4549,  0.1356,
         -5.4872],
        [-3.5196, -2.9854, -3.4061, -2.7625, -4.0084, -2.2667, -2.8081,  0.1879,
         -3.2530],
        [-4.3946, -3.7244, -3.9425, -3.1976, -4.8542, -2.6392, -3.2778,  0.2266,
         -4.0206],
        [-6.2789, -5.0195, -5.3452, -4.4717, -6.5454, -3.8482, -4.2905,  0.3575,
         -5.3789],
        [-3.6858, -2.8912, -3.2352, -2.5231, -3.7549, -2.0121, -2.6288,  0.2319,
         -3.0072],
        [-4.2799, -3.3910, -3.8091, -3.1673, -4.6260, -2.4199, -3.1019,  0.0455,
         -3.6165],
        [-5.1820, -4.1069, -4.7248, -3.6915, -5.4578, -3.1371, -3.6574,  0.4221,
         -4.6247],
        [-6.0723, -5.1864, -5.3267, -4.5962, -6.2532, -3.3303, -4.0994,  0.2661,
         -5.1210],
        [-6.5013, -5.5125, -5.8823, -4.8328, -6.7961, -3.7516, -4.6889,  0.5075,
   

predictions:  tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

model outputs:  tensor([[-8.3282, -6.3939, -7.0232, -5.7485, -8.3980, -4.2398, -5.7786, -0.7589,
         -7.0871],
        [-5.8931, -5.0348, -5.3593, -3.9554, -6.4062, -3.1287, -4.2123, -0.5058,
         -5.1777],
        [-6.8334, -5.7457, -6.4649, -5.3005, -7.8403, -3.5993, -5.1016, -0.6698,
         -6.0885],
        [-4.5780, -4.0061, -4.2154, -3.4137, -4.9587, -2.4947, -3.4050, -0.3355,
         -4.1954],
        [-3.6054, -2.9225, -3.1631, -2.4985, -3.6018, -1.8562, -2.4264, -0.3162,
         -3.0317],
        [-6.5115, -5.2558, -5.7712, -4.6949, -6.6389, -3.1507, -4.5644, -0.5120,
         -5.5905],
        [-5.8181, -4.8993, -5.2986, -4.3016, -6.3651, -2.9838, -4.2545, -0.4175,
         -5.0275],
        [-4.9883, -4.3273, -4.6825, -3.7201, -5.4252, -2.5291, -3.8170, -0.3765,
         -4.5893],
        [-4.8817, -4.2410, -4.3988, -3.6759, -5.1449, -2.3902, -3.6146, -0.3196,
         -4.2231],
        [-6.5106, -5.3888, -5.5003, -4.8044, -6.6349, -3.2105, -4.6770, -0.6711,
   

model outputs:  tensor([[-5.1970, -4.3251, -4.5275, -3.7613, -5.4217, -2.6117, -3.7390, -0.6708,
         -4.6281],
        [-4.3956, -3.7875, -4.1859, -3.3325, -4.8616, -2.3294, -3.2409, -0.4561,
         -3.7806],
        [-6.1115, -5.0338, -5.4577, -4.2935, -6.3326, -3.2072, -4.5076, -0.9414,
         -5.3708],
        [-5.6814, -4.4025, -4.8063, -3.9378, -5.8442, -2.9899, -4.2237, -0.7510,
         -4.6324],
        [-5.5254, -4.7739, -5.0210, -4.2426, -5.9528, -2.7063, -3.9575, -0.6941,
         -5.0067],
        [-5.0585, -4.4874, -4.6906, -3.8962, -5.5987, -2.5853, -3.9836, -0.6369,
         -4.4936],
        [-4.7072, -3.7617, -4.1223, -3.2771, -4.9277, -2.1214, -3.4526, -0.3715,
         -3.9823],
        [-7.1406, -6.1024, -6.4498, -5.5950, -7.5543, -3.5760, -5.4186, -0.8012,
         -6.4082],
        [-3.3352, -2.8572, -2.8754, -2.3693, -3.4300, -1.5909, -2.3907, -0.4062,
         -2.7923],
        [-6.4134, -5.4185, -5.5114, -4.2588, -6.5835, -3.1250, -4.5937, -0.7095,
   

model outputs:  tensor([[-8.5594, -7.1908, -7.7534, -6.5461, -9.3975, -4.2554, -6.0883, -1.3341,
         -8.0567],
        [-3.6479, -3.1378, -3.1310, -2.5830, -3.8771, -1.7809, -2.8661, -0.4073,
         -3.2145],
        [-7.1669, -5.7278, -6.3556, -5.1716, -7.3749, -3.3241, -5.0320, -0.7976,
         -6.2566],
        [-6.7792, -5.9609, -6.1814, -5.1032, -7.3971, -3.4311, -4.9713, -0.9043,
         -6.2694],
        [-6.8440, -5.6445, -6.0179, -4.6640, -6.8798, -3.1091, -5.0216, -0.5903,
         -5.9224],
        [-7.6198, -6.5629, -6.7288, -5.3145, -8.1243, -3.4685, -5.5051, -1.1455,
         -6.8456],
        [-5.3386, -4.5094, -4.9783, -3.9729, -5.7727, -2.6014, -4.0067, -0.6513,
         -4.7747],
        [-4.6721, -3.8522, -4.0386, -3.2770, -4.9621, -2.3667, -3.1979, -0.4124,
         -3.9383],
        [-6.9884, -5.6932, -6.0413, -4.8184, -7.4058, -3.2488, -4.8790, -0.5801,
         -6.2873],
        [-6.4263, -5.2447, -5.5666, -4.5503, -6.9681, -2.9978, -4.8456, -0.7303,
   

model outputs:  tensor([[-8.5012, -7.2094, -7.5837, -6.2840, -9.1823, -4.1680, -6.2796, -0.8625,
         -7.8275],
        [-9.2406, -7.2769, -7.8477, -6.2179, -9.2811, -4.2848, -6.5141, -0.3318,
         -8.0401],
        [-6.9011, -5.9209, -6.4849, -5.3911, -7.4712, -3.4906, -5.1977, -0.4666,
         -6.3778],
        [-4.8934, -3.9023, -4.4349, -3.3860, -5.0688, -2.2763, -3.4054, -0.5344,
         -4.1830],
        [-6.4440, -5.6969, -6.1031, -4.7448, -7.2551, -3.1360, -5.0738, -0.5617,
         -5.8572],
        [-6.8149, -5.8873, -6.2617, -5.1470, -7.0152, -3.5092, -4.9844, -0.4470,
         -6.1779],
        [-9.1553, -8.1956, -8.4118, -6.6611, -9.8031, -4.3184, -6.8232, -1.1028,
         -8.1874],
        [-7.3739, -6.0613, -6.2994, -5.2104, -7.9062, -3.5478, -5.4838, -0.3096,
         -6.5126],
        [-6.8042, -5.4886, -5.8227, -4.9437, -6.7096, -3.1714, -4.8264, -0.4413,
         -6.0955],
        [-6.6672, -5.8205, -6.1085, -4.9125, -7.2563, -3.2127, -5.0548, -0.8736,
   

model outputs:  tensor([[-8.1500e+00, -6.7637e+00, -6.8964e+00, -5.5557e+00, -8.5758e+00,
         -3.8296e+00, -5.9724e+00,  1.5305e-01, -7.1723e+00],
        [-7.0623e+00, -6.0700e+00, -6.2599e+00, -4.7569e+00, -7.4847e+00,
         -3.4216e+00, -5.2118e+00, -6.6689e-02, -6.6185e+00],
        [-8.0966e+00, -6.7706e+00, -7.1901e+00, -5.4747e+00, -8.4707e+00,
         -3.7765e+00, -5.9904e+00, -5.6272e-02, -6.8102e+00],
        [-8.0149e+00, -6.7240e+00, -7.2663e+00, -5.6552e+00, -8.0184e+00,
         -3.6307e+00, -5.7746e+00,  3.3686e-01, -7.0218e+00],
        [-7.1381e+00, -6.1560e+00, -6.4054e+00, -5.3211e+00, -7.6259e+00,
         -3.4613e+00, -5.4655e+00,  4.7187e-01, -6.3144e+00],
        [-5.2389e+00, -4.6250e+00, -4.8569e+00, -3.8518e+00, -5.8486e+00,
         -2.6046e+00, -3.8331e+00,  4.0311e-02, -4.5643e+00],
        [-6.6080e+00, -5.5342e+00, -5.5942e+00, -4.8105e+00, -6.9082e+00,
         -3.1980e+00, -4.5710e+00,  3.5018e-02, -5.7424e+00],
        [-7.3171e+00, -6.0696e+0

model outputs:  tensor([[-7.3472e+00, -6.3260e+00, -6.7156e+00, -4.9594e+00, -7.7594e+00,
         -3.8280e+00, -5.5411e+00,  3.2482e-04, -6.6613e+00],
        [-4.7271e+00, -4.1291e+00, -4.1852e+00, -3.3214e+00, -5.0275e+00,
         -2.3293e+00, -3.3900e+00,  6.2115e-02, -4.1536e+00],
        [-8.8253e+00, -7.3814e+00, -7.4198e+00, -6.0832e+00, -9.5357e+00,
         -4.1452e+00, -6.5995e+00,  1.6988e-01, -7.7739e+00],
        [-6.3872e+00, -5.4115e+00, -5.6925e+00, -4.4367e+00, -6.6215e+00,
         -3.0966e+00, -4.7133e+00,  6.0477e-02, -5.5109e+00],
        [-7.5629e+00, -6.9591e+00, -7.1075e+00, -5.5267e+00, -8.6438e+00,
         -3.7400e+00, -5.5866e+00,  2.9674e-01, -6.9998e+00],
        [-3.9693e+00, -3.2794e+00, -3.4360e+00, -2.6434e+00, -4.0887e+00,
         -1.9886e+00, -2.7686e+00,  2.1298e-01, -3.4026e+00],
        [-7.1650e+00, -6.0837e+00, -6.1376e+00, -4.9473e+00, -7.4043e+00,
         -3.4603e+00, -5.1248e+00,  3.8875e-01, -6.3417e+00],
        [-7.9717e+00, -6.9432e+0

KeyboardInterrupt: 