In [1]:
import os
if "x_perceiver" not in os.listdir():
    os.chdir("/home/kh701/pycharm/healnet/")
import torch
from torch import nn
import multiprocessing
import torchvision
import numpy as np
import torchvision.transforms as transforms
import einops
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from healnet.models.explainer import Explainer
pd.set_option('display.max_columns', 50)
pd.set_option('display.max_rows', 50)

from healnet.utils import Config, flatten_config
from healnet.etl import TCGADataset
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

    
%reload_ext autoreload
%autoreload 2

## Import data

In [2]:
# get dataloaders
config = Config("config/main_gpu.yml").read()
config = flatten_config(config) # TODO - refactor to other 

blca = TCGADataset(
    dataset="blca", 
    config=config, 
    level=2, 
    sources=["omic"]
)

brca = TCGADataset(
    dataset="brca", 
    config=config, 
    level=2, 
    sources=["omic"]
)



Filled 0 missing values with mean
Missing values per feature: 
 Series([], dtype: int64)
Slides available: 436
Omic available: 437
Overlap: 436
Filtering out 1 samples for which there are no omic data available
Dataloader initialised for blca dataset
Dataset: BLCA
Molecular data shape: (436, 2191)
Molecular/Slide match: 436/436
Slide level count: 4
Slide level dimensions: ((79968, 79653), (19992, 19913), (4998, 4978), (2499, 2489))
Slide resize dimensions: w: 1024, h: 1024
Sources selected: ['omic']
Censored share: 0.539
Survival_bin_sizes: {0: 72, 1: 83, 2: 109, 3: 172}
Filled 0 missing values with mean
Missing values per feature: 
 Series([], dtype: int64)
Slides available: 1019
Omic available: 1022
Overlap: 1019
Filtering out 3 samples for which there are no omic data available
Dataloader initialised for brca dataset
Dataset: BRCA
Molecular data shape: (1019, 2922)
Molecular/Slide match: 1019/1019
Slide level count: 3
Slide level dimensions: ((35855, 34985), (8963, 8746), (2240, 218

In [98]:
# get tabular data
blca_loader = DataLoader(
    blca, 
    batch_size=1, 
    shuffle=True, 
    num_workers=multiprocessing.cpu_count()-1
)
[sample], censorship, event_time, y_disc = next(iter(blca_loader))

In [99]:
sample.shape

torch.Size([1, 1, 2183])

## Tabular self-supervised pre-training

To start with, we want to build and encoder-decoder model which trains a cross-attention unit as the encoder, which can later on be deployed in the iterative model. We then want to benchmark the performance with pan-cancer pre-training vs. without pre-training. 

In [81]:
from healnet.models.healnet import Attention, PreNorm

class AttentionEncoder(nn.Module): 
    """
    Simple encoder that uses fourier encoding, pre-norm and cross-attention to encode the input features into a latent array 
    of size (num_latents x latent_dim). Takes in both the input tensors as well as a randomly initialised latent 
    array as the input. 
    """
    def __init__(self, 
                 input_channels: int,
                 latent: torch.Tensor, 
                 input_axis: int = 1, 
                 attn_dropout: float = 0.1,
                 num_heads: int = 4, 
                 num_freq_bands: int=8, 
                 ):    
        super().__init__()
        
        self.input_channels = input_channels
        self.input_axis = input_axis
        self.attn_dropout = attn_dropout
        self.num_heads = num_heads
        
        
        # fourier_channels = (input_axis * ((num_freq_bands * 2) + 1))
        # input_dim = fourier_channels + input_channels
        input_dim = input_channels
                
        latent_dim = latent.shape[-1] # required for PreNorm layer
        # simple single attention unit
        enc = PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=num_heads, dim_head=num_heads, dropout=attn_dropout), context_dim=input_dim)
        
        self.layers = nn.ModuleList([enc])
        
    def forward(self, latent: torch.Tensor, context: torch.Tensor):
        """
        Note: context is the data, x is the latent
        Args:
            latent: 
            context: 

        Returns:

        """
        for layer in self.layers:
            latent = layer(x=latent, context=context)
        return latent


The decoder often needs to be different depending on the modality, so let's implement modality-specific decoders while trying to have a relatively general-purpose encoder that we can plug into the pipeline.

Note that we may change this later down the line. 

In [82]:
class TabularDecoder(nn.Module):
    """
    Decoder suited for tabular data. We use the following: 
    - Skip connections: faster and more stable training
    - Batch normalisation: stabilises the activations and speeds up training
    - Activation: Output layer to map back to output dimensions, corresponding to the original data dims
    Tries to reconstruct the original input given the latent
    """
    def __init__(self, latent_dim: int, output_dim: int):
        super(TabularDecoder, self).__init__()
        
        layers = []
        hidden_dims = [128, 256] # may refactor as hyperparameter later
        
        for hidden_dim in hidden_dims: 
            layers.extend([
                nn.Linear(in_features=latent_dim, out_features=hidden_dim), 
                nn.LeakyReLU(), 
                # nn.BatchNorm1d(hidden_dim, track_running_stats=False), 
                nn.Dropout(0.5)
            ])
            
            latent_dim = hidden_dim # update for next layer
        
        # final layer to reconstruct output
        layers.append(nn.Linear(latent_dim, output_dim))
        
        self.decode = nn.Sequential(*layers)
        print(self.decode)
        
    def forward(self, latent: torch.Tensor):
        return self.decode(latent)
    
    
        
    

Finally, putting it all together in the encoder-decoder model


torch.float32

In [102]:
from typing import *

class TabPretrainer(nn.Module): 
    """
    Encoder-decoder model for pre-training tabular data.
    # TODO - refactor abstract base class for initialisations 
    """
    def __init__(self,
                 sample: torch.Tensor,
                 # input_channels: int,
                 latent_shape: List[int],
                 input_axis: int = 1,
                 attn_dropout: float = 0.1,
                 num_heads: int = 4,
                 num_freq_bands: int=8,
                 ):
        super().__init__()
        self.input_channels = sample.shape[-1]
        self.input_axis = input_axis
        self.num_latents, self.latent_dim = latent_shape
        self.attn_dropout = attn_dropout
        self.num_heads = num_heads
        self.num_freq_bands = num_freq_bands
        
        
        # randomly initialise latent
        self.latent = nn.Parameter(torch.randn(self.num_latents, self.latent_dim))
        
        # encoder
        self.encoder = AttentionEncoder(
            input_channels=self.input_channels, 
            latent=self.latent, 
            input_axis=self.input_axis, 
            attn_dropout=attn_dropout, 
            num_heads=num_heads, 
            num_freq_bands=num_freq_bands
        )
        
        # decoder
        self.decoder = TabularDecoder(
            latent_dim=self.latent_dim, 
            output_dim=self.input_channels
        )
        
    def forward(self, x: torch.Tensor):
        # get batch dim
        b = x.shape[0]
        
        # expand latent to batch size
        if len(self.latent.shape) == 2:
            self.latent = nn.Parameter(einops.repeat(self.latent, "n d -> b n d", b=b))
        
        # encode
        self.latent.data = self.encoder(latent=self.latent, context=x).data + self.latent.data
        
        # decode, reconstructed x
        rec_x = self.decoder(self.latent)
        return rec_x
    
    def get_latent(self):
        return self.latent

Next, we need to think about tabular loss functions. Here, we can explore both reconstruction losses and contrastive losses. 

In [8]:
class TabularLoss(nn.Module):
    """
    Reconstruction loss functions for tabular data. We use two types which are commonly used with continuous data: 
    - Mean squared error
    - Constrastive loss, measured as cosine distance between the original and reconstructed data
    We seek to minimise both objectives.
    """
    def __init__(self,
                 method: str = "mse",
                 reduction: str = "mean",
                 ):
        super().__init__()
        assert method in ["mse", "contrastive"], "Loss type not recognised"
        self.loss_type = method
        self.reduction = reduction
        
        if method == "mse":
            self.loss = nn.MSELoss(reduction=reduction)
        elif method == "contrastive":
            self.loss = nn.CosineEmbeddingLoss(reduction=reduction)
            
    def __call__(self, x: torch.Tensor, rec_x: torch.Tensor):
        return self.loss(x, rec_x)
    

Finally, we write a pre-training loop that we can use for pre-training across cancer sites. 

In [9]:
# get overlap between omic columns
col1 = blca.omic_df.columns
col2 = brca.omic_df.columns
print(len(col1), len(col2), len(set(col1).intersection(col2)))

2191 2922 1758


In [95]:
blca.features

# types of features
# continuous: *_rnaseq, age
# categorical: *_mut, *_cnv 

Unnamed: 0,age,is_female,AAK1_rnaseq,AATK_rnaseq,ABCB1_rnaseq,ABCG2_rnaseq,ABI1_rnaseq,ABL1_rnaseq,ABL2_rnaseq,ACE_rnaseq,ACKR1_rnaseq,ACKR3_rnaseq,ACSL3_rnaseq,ACSL6_rnaseq,ACVR1B_rnaseq,ACVR1C_rnaseq,ACVR1_rnaseq,ACVR2A_rnaseq,ACVR2B_rnaseq,ACVRL1_rnaseq,ADAM10_rnaseq,ADAM17_rnaseq,ADCK1_rnaseq,ADCK2_rnaseq,ADCK5_rnaseq,...,UTRN_mut,VCAN_mut,VPS13B_mut,VPS13C_mut,VPS13D_mut,WDFY3_mut,WNK1_mut,XIRP2_mut,XIST_mut,ZDBF2_mut,ZFHX3_mut,ZFHX4_mut,ZFP36L1_mut,ZFYVE26_mut,ZFYVE9_mut,ZNF236_mut,ZNF292_mut,ZNF423_mut,ZNF521_mut,ZNF536_mut,ZNF626_mut,ZNF804A_mut,ZNF91_mut,ZZEF1_mut,RAS_mut
0,63,0,-0.6734,-0.4660,0.8401,-0.2222,2.2318,-0.8171,0.8051,-0.1250,-0.2976,1.2538,-0.3237,-0.1429,0.5258,-0.0748,-0.2048,-0.3004,0.2998,-0.6414,1.2149,1.1643,0.3720,-0.2883,-0.1974,...,1,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1
1,66,0,2.4277,-0.3853,0.1104,-0.2183,-0.0952,-0.6255,0.0970,-0.4911,-0.1779,-0.4134,-0.1501,-0.1576,-0.3597,0.4555,-1.0758,0.3252,1.7109,-0.5763,2.5860,1.5608,-0.6966,0.1801,-0.3164,...,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
2,66,0,2.4277,-0.3853,0.1104,-0.2183,-0.0952,-0.6255,0.0970,-0.4911,-0.1779,-0.4134,-0.1501,-0.1576,-0.3597,0.4555,-1.0758,0.3252,1.7109,-0.5763,2.5860,1.5608,-0.6966,0.1801,-0.3164,...,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
3,69,0,1.1340,-0.4110,0.1572,0.0752,0.0566,-1.3448,-0.3876,1.0335,-0.3683,-0.3736,-0.3294,-0.1807,0.8215,-0.7729,-0.7901,-0.9142,-0.2716,-0.2526,1.2477,0.8202,-0.1294,0.7846,-0.2564,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,59,1,-0.5311,0.1418,-0.0998,-0.2493,-0.6956,-0.3696,-0.1672,-0.7257,-0.3450,-0.1465,0.2727,0.3077,1.3352,2.3315,0.2386,1.6382,-0.2124,-0.7441,0.9661,0.6493,-1.2289,-0.0261,-0.1046,...,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
432,71,0,2.8284,0.9219,-0.4711,-0.1561,-0.7569,-0.0186,2.2843,1.9383,-0.3507,-0.4448,-0.4321,-0.1794,-1.0555,-0.6594,1.4145,-1.0607,-0.2349,0.9236,-0.8023,1.6165,-1.0423,-1.2719,-0.7105,...,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
433,61,1,0.9422,0.2662,-0.5276,-0.3078,-0.5164,-0.3653,1.6644,0.4936,-0.3722,-0.5500,-0.3539,-0.1324,-1.1019,-0.7735,0.0458,-1.1459,-0.6342,0.4675,-1.0416,-0.9597,-1.3496,0.4849,1.6007,...,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1
434,60,1,-0.3000,-0.5301,-0.5559,0.4720,0.0597,-0.1817,3.6724,0.1842,-0.3892,-0.6284,0.2315,-0.1468,-1.1972,-0.5964,-0.2927,-0.9101,-0.6431,-0.3869,-0.5002,-0.0913,-0.6892,-0.6854,1.1884,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
435,62,1,3.2208,-0.2592,-0.8130,-0.1423,-1.2421,-1.4423,-0.6631,-0.9650,-0.3868,1.9972,0.0008,-0.1583,-1.4766,-0.7790,-1.2992,-1.0639,0.5145,-0.8358,-0.1595,-0.0378,-0.3054,0.7819,1.9629,...,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1


In [89]:
from tqdm import tqdm

torch.set_printoptions(sci_mode=False)


def pretrain_loop(
        data: TCGADataset,
        batch_size: int, 
    ):
    
    loader = DataLoader(
        data, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=multiprocessing.cpu_count()-1
    )
    [omic_sample], _, _, _ = next(iter(loader))   
    
    
    model = TabPretrainer(
        sample = omic_sample,
        # input_channels=omic_sample.shape[-1], 
        input_axis=1, 
        latent_shape=[256, 32], 
        attn_dropout=0.1, 
        num_heads=4,
        num_freq_bands=8
    )
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    loss_fn = TabularLoss(method="mse")
    
    for epoch in tqdm(range(10)):
        for idx, batch in enumerate(loader):
            [omic], censorship, event_time, y_disc = batch
            rec_omic = model(omic)
            loss = loss_fn(omic, rec_omic)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # print every 10th batch
            if idx % 100 == 0:
                print(loss)
                print(omic)
                print(rec_omic)
            
    
pretrain_loop(
    data=blca, 
    batch_size=1)
    

Sequential(
  (0): Linear(in_features=32, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.01)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=128, out_features=256, bias=True)
  (4): LeakyReLU(negative_slope=0.01)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=256, out_features=2183, bias=True)
)


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(2.8114, grad_fn=<MseLossBackward0>)
tensor([[[    66.0000,      0.0000,     -0.0004,  ...,      0.0000,
               0.0000,      0.0000]]])
tensor([[[     0.2862,      0.2241,     -0.1490,  ...,      0.1617,
               0.1011,     -0.1182],
         [     0.0690,      0.0426,      0.0028,  ...,     -0.0440,
               0.0984,     -0.1003],
         [    -0.0538,      0.0633,     -0.1094,  ...,     -0.2010,
              -0.1546,     -0.0747],
         ...,
         [    -0.2268,     -0.0646,      0.1491,  ...,     -0.2153,
               0.0527,      0.0817],
         [    -0.0800,      0.1510,     -0.1117,  ...,     -0.1109,
               0.0588,      0.1342],
         [    -0.2746,      0.1651,     -0.0003,  ...,     -0.0946,
               0.0997,      0.2341]]], grad_fn=<ViewBackward0>)
tensor(1.3467, grad_fn=<MseLossBackward0>)
tensor([[[50.0000,  1.0000, -1.2060,  ...,  0.0000,  0.0000,  1.0000]]])
tensor([[[    27.5842,     -0.2846,      0.1688,  ...,      0.1

 10%|█         | 1/10 [00:05<00:47,  5.30s/it]

tensor(5.9094, grad_fn=<MseLossBackward0>)
tensor([[[80.0000,  0.0000, -1.8057,  ...,  0.0000,  0.0000,  1.0000]]])
tensor([[[   70.4145,     0.4202,     0.8595,  ...,     0.1450,
              0.0497,     0.1270],
         [   46.8503,     0.3871,     0.5777,  ...,     0.1016,
              0.0222,     0.1752],
         [   49.7496,     0.3898,     0.5145,  ...,     0.1003,
              0.0303,     0.1176],
         ...,
         [   67.4621,     0.5773,     0.7473,  ...,     0.1210,
              0.0135,     0.1590],
         [   50.6594,     0.3019,     0.3518,  ...,     0.0831,
              0.0086,     0.1310],
         [   52.9864,     0.4540,     0.6558,  ...,     0.1102,
              0.0161,     0.1233]]], grad_fn=<ViewBackward0>)
tensor(1.6493, grad_fn=<MseLossBackward0>)
tensor([[[46.0000,  0.0000, -0.6910,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    60.9179,      0.0647,      1.1114,  ...,     -0.0038,
               0.2147,      0.2177],
         [    60.0517,      

 20%|██        | 2/10 [00:12<00:50,  6.30s/it]

tensor(0.7872, grad_fn=<MseLossBackward0>)
tensor([[[53.0000,  1.0000,  1.3446,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    52.6728,      0.1659,     -0.0653,  ...,      0.3115,
               0.1767,      0.1223],
         [    71.6219,      0.2006,     -0.0610,  ...,      0.3982,
               0.2194,      0.1343],
         [    43.1795,      0.1165,     -0.0258,  ...,      0.2429,
               0.1552,      0.1066],
         ...,
         [    57.4983,      0.1863,     -0.1262,  ...,      0.3237,
               0.1927,      0.1456],
         [    46.9717,      0.1378,      0.0604,  ...,      0.3120,
               0.1601,      0.1390],
         [    37.7055,      0.1338,     -0.0741,  ...,      0.1984,
               0.1208,      0.1106]]], grad_fn=<ViewBackward0>)
tensor(1.3723, grad_fn=<MseLossBackward0>)
tensor([[[60.0000,  0.0000, -0.4841,  ...,  0.0000,  1.0000,  0.0000]]])
tensor([[[    48.0910,      0.3337,      0.0040,  ...,      0.2478,
               0.2922,     -0

 30%|███       | 3/10 [00:17<00:40,  5.80s/it]

tensor(1.3250, grad_fn=<MseLossBackward0>)
tensor([[[90.0000,  1.0000, -1.3975,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    56.6554,      0.1947,      0.0247,  ...,      0.0298,
              -0.0858,     -0.0812],
         [    78.4231,      0.4030,      0.1715,  ...,      0.0720,
              -0.0808,     -0.0739],
         [    72.3756,      0.2020,      0.2598,  ...,      0.0620,
              -0.1138,     -0.1021],
         ...,
         [    64.6823,      0.2011,      0.1277,  ...,      0.0780,
              -0.0839,     -0.1099],
         [    73.4735,      0.2988,      0.1893,  ...,      0.0153,
              -0.0963,     -0.1023],
         [    90.8132,      0.2977,      0.0933,  ...,      0.0955,
              -0.1662,     -0.1403]]], grad_fn=<ViewBackward0>)
tensor(4.9299, grad_fn=<MseLossBackward0>)
tensor([[[54.0000,  0.0000,  1.1411,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[   59.9985,     0.4494,     0.0254,  ...,     0.1771,
              0.1125,     0.6181]

 40%|████      | 4/10 [00:22<00:33,  5.51s/it]

tensor(0.6281, grad_fn=<MseLossBackward0>)
tensor([[[62.0000,  0.0000,  1.7927,  ...,  0.0000,  1.0000,  0.0000]]])
tensor([[[   67.0889,     0.1063,     0.5982,  ...,     0.1284,
              0.1595,     0.1061],
         [   35.3526,     0.1045,     0.2856,  ...,     0.0820,
              0.1173,     0.0828],
         [   57.9124,     0.1593,     0.4837,  ...,     0.1053,
              0.1591,     0.1147],
         ...,
         [   60.8047,     0.1816,     0.6548,  ...,     0.1314,
              0.2011,     0.0538],
         [   61.6667,     0.1381,     0.4300,  ...,     0.1275,
              0.1836,     0.0780],
         [   56.8910,     0.2256,     0.5713,  ...,     0.1133,
              0.1580,     0.0799]]], grad_fn=<ViewBackward0>)
tensor(4.3926, grad_fn=<MseLossBackward0>)
tensor([[[73.0000,  0.0000,  2.7281,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    58.4556,      0.6258,      0.2171,  ...,      0.0994,
              -0.0397,      0.0447],
         [    64.0984,      

 50%|█████     | 5/10 [00:27<00:26,  5.32s/it]

tensor(1.3158, grad_fn=<MseLossBackward0>)
tensor([[[60.0000,  0.0000, -0.1458,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    59.2085,     -0.1113,      0.2274,  ...,     -0.0343,
              -0.1236,      0.0856],
         [    55.0480,     -0.0787,      0.3507,  ...,     -0.1292,
              -0.1091,      0.1304],
         [    40.4069,     -0.0640,      0.4030,  ...,     -0.0095,
              -0.0984,      0.0647],
         ...,
         [    72.3045,     -0.0780,      0.5292,  ...,     -0.1860,
              -0.1316,      0.0386],
         [    65.2534,     -0.2323,      0.2850,  ...,     -0.1103,
              -0.1161,      0.0951],
         [    62.6999,     -0.1073,      0.6269,  ...,     -0.1354,
              -0.1093,      0.0544]]], grad_fn=<ViewBackward0>)
tensor(2.3578, grad_fn=<MseLossBackward0>)
tensor([[[68.0000,  0.0000, -0.8200,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    48.5509,      0.2579,     -0.1082,  ...,      0.0543,
               0.0569,      0

 60%|██████    | 6/10 [00:32<00:20,  5.21s/it]

tensor(0.8511, grad_fn=<MseLossBackward0>)
tensor([[[76.0000,  1.0000, -0.0903,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    59.2018,      0.1619,      0.7460,  ...,      0.0845,
               0.0203,      0.1572],
         [    76.8315,      0.2324,      0.9474,  ...,      0.0114,
               0.0456,      0.2297],
         [    70.1700,      0.2920,      1.0035,  ...,      0.0456,
               0.0306,      0.1538],
         ...,
         [    70.4930,      0.2366,      0.7661,  ...,     -0.0027,
               0.0198,      0.1414],
         [    77.9535,      0.3550,      1.0316,  ...,      0.0641,
               0.0370,      0.1569],
         [    95.0761,      0.2743,      1.0437,  ...,      0.0978,
               0.0395,      0.1399]]], grad_fn=<ViewBackward0>)
tensor(2.2967, grad_fn=<MseLossBackward0>)
tensor([[[75.0000,  0.0000,  0.4331,  ...,  0.0000,  1.0000,  0.0000]]])
tensor([[[    56.1272,      0.4327,      0.7006,  ...,      0.0304,
               0.0265,      0

 70%|███████   | 7/10 [00:37<00:15,  5.16s/it]

tensor(0.7166, grad_fn=<MseLossBackward0>)
tensor([[[50.0000,  0.0000,  0.3542,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    77.4723,      0.6541,     -0.0931,  ...,      0.2203,
               0.1727,      0.5865],
         [    67.8820,      0.4855,     -0.0405,  ...,      0.1503,
               0.1269,      0.4843],
         [    42.5054,      0.4351,      0.0122,  ...,      0.1304,
               0.1186,      0.3043],
         ...,
         [    52.6750,      0.3535,      0.0091,  ...,      0.1441,
               0.0694,      0.3689],
         [    65.5573,      0.5618,      0.0367,  ...,      0.2242,
               0.1227,      0.3993],
         [    65.2112,      0.5782,     -0.1814,  ...,      0.1645,
               0.1913,      0.4989]]], grad_fn=<ViewBackward0>)
tensor(2.3678, grad_fn=<MseLossBackward0>)
tensor([[[57.0000,  0.0000, -0.2452,  ...,  1.0000,  0.0000,  0.0000]]])
tensor([[[   113.4000,      1.2569,      1.2394,  ...,      0.4033,
               0.0184,      0

 80%|████████  | 8/10 [00:43<00:10,  5.30s/it]

tensor(1.0333, grad_fn=<MseLossBackward0>)
tensor([[[65.0000,  0.0000,  0.2180,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    76.0294,      0.3167,      1.4031,  ...,      0.0335,
               0.2552,     -0.0145],
         [    46.4658,      0.2878,      0.8713,  ...,      0.0194,
               0.1253,      0.0909],
         [    53.4063,      0.2968,      0.9068,  ...,     -0.0083,
               0.1882,      0.0103],
         ...,
         [    54.3103,      0.2933,      0.8695,  ...,      0.0692,
               0.1406,      0.0473],
         [    58.9399,      0.1696,      0.9165,  ...,     -0.0187,
               0.1197,      0.0007],
         [    48.9617,      0.2630,      0.9389,  ...,     -0.0358,
               0.0384,      0.0333]]], grad_fn=<ViewBackward0>)
tensor(1.3296, grad_fn=<MseLossBackward0>)
tensor([[[66.0000,  0.0000,  2.0546,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[   76.3845,     0.3484,     0.8452,  ...,     0.0563,
              0.5274,     0.1549]

 90%|█████████ | 9/10 [00:48<00:05,  5.38s/it]

tensor(1.1931, grad_fn=<MseLossBackward0>)
tensor([[[69.0000,  1.0000,  1.0457,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    57.4249,      0.6660,      0.1118,  ...,      0.3605,
               0.1569,      0.2772],
         [    68.2448,      0.4291,      0.1188,  ...,      0.1224,
               0.1939,      0.0949],
         [    57.1778,      0.2097,      0.2779,  ...,      0.0536,
               0.0017,      0.1204],
         ...,
         [    53.6904,      0.1413,      0.1187,  ...,      0.0085,
               0.1349,     -0.0009],
         [    76.3008,      0.6065,      0.1208,  ...,      0.3842,
               0.1497,      0.1964],
         [    41.7606,      0.1448,      0.0315,  ...,      0.0707,
               0.1008,      0.0910]]], grad_fn=<ViewBackward0>)
tensor(1.9294, grad_fn=<MseLossBackward0>)
tensor([[[74.0000,  0.0000, -0.0806,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[[    46.0793,      0.1049,      0.2510,  ...,      0.2301,
              -0.1970,     -0

100%|██████████| 10/10 [00:53<00:00,  5.39s/it]
