In [3]:
import os
if "x_perceiver" not in os.listdir():
    os.chdir("/home/kh701/pycharm/healnet/")
import torch
from torch import nn
import multiprocessing
import torchvision
import numpy as np
import torchvision.transforms as transforms
import einops
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from healnet.models.explainer import Explainer
pd.set_option('display.max_columns', 50)
pd.set_option('display.max_rows', 50)

from healnet.utils import Config, flatten_config
from healnet.etl import TCGADataset
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

    
%reload_ext autoreload
%autoreload 2

## Import data

In [12]:
# get dataloaders
config = Config("config/main_gpu.yml").read()
config = flatten_config(config) # TODO - refactor to other 

blca = TCGADataset(
    dataset="blca", 
    config=config, 
    level=2, 
    sources=["omic"]
)

brca = TCGADataset(
    dataset="brca", 
    config=config, 
    level=2, 
    sources=["omic"]
)



Filled 0 missing values with mean
Missing values per feature: 
 Series([], dtype: int64)
Slides available: 436
Omic available: 437
Overlap: 436
Filtering out 1 samples for which there are no omic data available
Dataloader initialised for blca dataset
Dataset: BLCA
Molecular data shape: (436, 2191)
Molecular/Slide match: 436/436
Slide level count: 4
Slide level dimensions: ((79968, 79653), (19992, 19913), (4998, 4978), (2499, 2489))
Slide resize dimensions: w: 1024, h: 1024
Sources selected: ['omic']
Censored share: 0.539
Survival_bin_sizes: {0: 72, 1: 83, 2: 109, 3: 172}
Filled 0 missing values with mean
Missing values per feature: 
 Series([], dtype: int64)
Slides available: 1019
Omic available: 1022
Overlap: 1019
Filtering out 3 samples for which there are no omic data available
Dataloader initialised for brca dataset
Dataset: BRCA
Molecular data shape: (1019, 2922)
Molecular/Slide match: 1019/1019
Slide level count: 3
Slide level dimensions: ((35855, 34985), (8963, 8746), (2240, 218

In [13]:
# get tabular data
blca_loader = DataLoader(
    blca, 
    batch_size=1, 
    shuffle=True, 
    num_workers=multiprocessing.cpu_count()-1
)
[sample], censorship, event_time, y_disc = next(iter(blca_loader))

In [14]:
sample.shape

torch.Size([1, 1, 2183])

## Tabular self-supervised pre-training

To start with, we want to build and encoder-decoder model which trains a cross-attention unit as the encoder, which can later on be deployed in the iterative model. We then want to benchmark the performance with pan-cancer pre-training vs. without pre-training. 

In [80]:
from healnet.models.healnet import PreNorm, default, temperature_softmax, exists
from einops import rearrange, repeat
from torch import nn, einsum


# class Attention(nn.Module):
#     def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.):
#         super().__init__()
#         inner_dim = dim_head * heads
#         context_dim = default(context_dim, query_dim) # self-attention if context is not provided
#         
#         self.scale = dim_head ** -0.5
#         self.heads = heads
# 
#         self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
#         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
# 
#         self.dropout = nn.Dropout(dropout)
#         # add leaky relu
#         self.to_out = nn.Sequential(
#             nn.Linear(inner_dim, context_dim),
#             nn.LeakyReLU(negative_slope=1e-2)
#         )
# 
#         self.attn_weights = None
#         # self._init_weights()
# 
#     def _init_weights(self):
#     # Use He initialization for Linear layers
#         for m in self.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
#                 # Initialize bias to zero if there's any
#                 if m.bias is not None:
#                     nn.init.zeros_(m.bias)
# 
#     def forward(self, x, context = None, mask = None):
#         h = self.heads
#         print(x.shape)
#         print(context.shape)
# 
#         q = self.to_q(x)
#         context = default(context, x)
#         k, v = self.to_kv(context).chunk(2, dim = -1)
# 
#         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
# 
#         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
# 
#         if exists(mask):
#             mask = rearrange(mask, 'b ... -> b (...)')
#             max_neg_value = -torch.finfo(sim.dtype).max
#             mask = repeat(mask, 'b j -> (b h) () j', h = h)
#             sim.masked_fill_(~mask, max_neg_value)
# 
#         # attention, what we cannot get enough of
#         # attn = sim.softmax(dim = -1)
#         attn = temperature_softmax(sim, temperature=0.5, dim=-1)
#         self.attn_weights = attn
#         print("Attn weights", self.attn_weights.shape)
#         attn = self.dropout(attn)
# 
# 
#         out = einsum('b i j, b j d -> b i d', attn, v)
#         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
#         out = self.to_out(out)
#         return out

In [None]:
class LatentCrossAttention(nn.Module): 
    pass

class AttentionUpdate(nn.Module): 
    pass


class RecurrentAttention(nn.Module): 
    pass



In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentCrossAttention(nn.Module):
    def __init__(self, query_dim=1, latent_dim=32):
        super(LatentCrossAttention, self).__init__()
        
        # Weight matrices for projecting query and context (key)
        self.w_q = nn.Linear(query_dim, latent_dim, bias=False)  # Linear layer will act as our transformation matrix for Q
        self.w_c = nn.Linear(latent_dim, latent_dim, bias=False)    # Transformation matrix for K (context)
        
    def forward(self, query, context):
        # Linear projections of query and context
        q_proj = self.w_q(query)  # [b, 2189, 32]
        k_proj = self.w_c(context)  # [b, 256, 32]
        
        # Calculating attention scores
        S = torch.bmm(q_proj, k_proj.transpose(1, 2))  # [b, 2189, 256]
        
        # Summing over the second dimension to get the required size
        S_summed = torch.sum(S, dim=-1)  # [b, 2189]
        
        # Calculating attention weights
        attn = F.softmax(S_summed, dim=-1)  # [b, 2189]
        
        # element-wise (hadamard) product
        query_prime = attn.unsqueeze(-1) * query  # [b, 2189, 1]
        
        return attn, query_prime

class LatentUpdate(nn.Module):
    def __init__(self, input_size=256, hidden_size=256, num_layers=1):
        super(LatentUpdate, self).__init__()
        
        self.gru = nn.GRU(
            input_size=input_size, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=True
        )
    
    def forward(self, latent, input_data):
        # Reshape input_data to shape [b, 9, 256]
        input_data = input_data.view(latent.size(0), 9, -1)
        
        # Use latent as the initial hidden state
        _, h_n = self.gru(input_data, latent)
        
        return h_n



# Test case
b = 10
# tabular
num_features = 2189
num_channels = 1 # just one channel
num_latents = 256
latent_dim = 32
# latent_dim
query = torch.randn(b, num_features, num_channels)
context = torch.randn(b, 256, 32)

attention_module = LatentCrossAttention(query_dim=num_channels, latent_dim=latent_dim)
attention_weights, query_prime = attention_module(query, context)

print(attention_weights.shape)  # Expected: torch.Size([64, 2189])
print(query_prime.shape)

latent_update = LatentUpdate()
context_prime = latent_update(context, query_prime)


torch.Size([10, 2189])
torch.Size([10, 2189, 1])


RuntimeError: shape '[10, 9, -1]' is invalid for input of size 21890

In [163]:
print(latent)
print(updated_latent)

tensor([[[0.2788, 0.3828, 0.0111,  ..., 0.9061, 0.6000, 0.9496],
         [0.5129, 0.5112, 0.7071,  ..., 0.1710, 0.8764, 0.9519],
         [0.0934, 0.4184, 0.9597,  ..., 0.2585, 0.3434, 0.8287],
         ...,
         [0.1092, 0.0657, 0.0180,  ..., 0.1797, 0.8958, 0.2064],
         [0.7158, 0.6836, 0.9376,  ..., 0.6679, 0.9081, 0.4610],
         [0.5774, 0.4250, 0.7896,  ..., 0.6631, 0.3798, 0.5673]]])
tensor([[[-0.1944, -0.0049, -0.5161,  ...,  0.5245,  0.3983, -0.0208],
         [-0.1944, -0.0049, -0.5161,  ...,  0.5245,  0.3983, -0.0208],
         [-0.1944, -0.0049, -0.5161,  ...,  0.5245,  0.3983, -0.0208],
         ...,
         [-0.1944, -0.0049, -0.5161,  ...,  0.5245,  0.3983, -0.0208],
         [-0.1944, -0.0049, -0.5161,  ...,  0.5245,  0.3983, -0.0208],
         [-0.1944, -0.0049, -0.5161,  ...,  0.5245,  0.3983, -0.0208]]],
       grad_fn=<RepeatBackward0>)


tensor([[[0.7892, 0.3836, 0.1512,  ..., 0.1156, 0.4322, 0.1931]]])

In [110]:

class AttentionEncoder(nn.Module): 
    """
    Simple encoder that uses fourier encoding, pre-norm and cross-attention to encode the input features into a latent array 
    of size (num_latents x latent_dim). Takes in both the input tensors as well as a randomly initialised latent 
    array as the input. 
    """
    def __init__(self, 
                 input_channels: int,
                 latent: torch.Tensor, 
                 input_axis: int = 1, 
                 attn_dropout: float = 0.1,
                 num_heads: int = 4, 
                 num_freq_bands: int=8, 
                 ):    
        super().__init__()
        
        self.input_channels = input_channels
        self.input_axis = input_axis
        self.attn_dropout = attn_dropout
        self.num_heads = num_heads
        
        
        # fourier_channels = (input_axis * ((num_freq_bands * 2) + 1))
        # input_dim = fourier_channels + input_channels
        input_dim = input_channels
                
        num_latents, latent_dim = latent.shape
        # latent_dim = latent.shape[-1] # required for PreNorm layer
        # num_latents = latent.shape[0]
        # simple single attention unit
        print(latent_dim, input_dim)
        # enc = PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=num_heads, dim_head=num_heads, dropout=attn_dropout), context_dim=latent_dim)
        enc = Attention(query_dim=input_dim, num_latents=num_latents, latent_dim=latent_dim, heads=num_heads, dim_head=num_heads, dropout=attn_dropout)
        # enc = PreNorm
        # enc = PreNorm(latent_dim, Attention(query_dim=input_dim, context_dim=latent_dim, heads=num_heads, dim_head=num_heads, dropout=attn_dropout), context_dim=latent_dim)
        
        self.layers = nn.ModuleList([enc])
        
        print(enc)
        
    def forward(self, query: torch.Tensor, latent: torch.Tensor):
        """
        Note: context is the data, x is the latent
        Args:
            latent: 
            context: 

        Returns:

        """
        for layer in self.layers:
            latent = layer(x=query, context=latent)
        return latent


The decoder often needs to be different depending on the modality, so let's implement modality-specific decoders while trying to have a relatively general-purpose encoder that we can plug into the pipeline.

Note that we may change this later down the line. 

In [26]:
class TabularDecoder(nn.Module):
    """
    Decoder suited for tabular data. We use the following: 
    - Skip connections: faster and more stable training
    - Batch normalisation: stabilises the activations and speeds up training
    - Activation: Output layer to map back to output dimensions, corresponding to the original data dims
    Tries to reconstruct the original input given the latent
    """
    def __init__(self, latent_dim: int, num_latents: int, output_dim: int, method: str = "dense"):
        super(TabularDecoder, self).__init__()
        assert method in ["dense", "conv"], "Decoder type not recognised"
        # check that latent_dim is divisible by 4
        assert num_latents % 4 == 0, "Latent dim must be a multiple of 4"
        layers = []
        
        if method == "dense": 
            
            # flatten latent array (batch, num_latents, latent_dim) -> (batch, num_latents * latent_dim)
            layers.extend([nn.Flatten()]) 
            out_dims = [1024, 512, 256] # may refactor as hyperparameter later
            
            in_dim = latent_dim * num_latents
            for idx, out_dim in enumerate(out_dims):
                
                layers.extend([
                    nn.Linear(in_features=in_dim, out_features=out_dim), 
                    nn.LeakyReLU(), 
                    nn.InstanceNorm1d(out_dim, track_running_stats=False), 
                    nn.Dropout(0.5)
                ])
                
                in_dim = out_dim # update for next layer
            
            # final layer to reconstruct output
            layers.append(nn.Linear(in_dim, output_dim))
        
        elif method == "conv": 
            print(latent_dim, num_latents)
            layers.extend([
                nn.ConvTranspose1d(num_latents, out_channels=int(num_latents/2), kernel_size=4, stride=2, padding=1), 
                nn.BatchNorm1d(int(num_latents/2)),
                nn.LeakyReLU(negative_slope=0.1),
                
                nn.ConvTranspose1d(int(num_latents/2), out_channels=int(num_latents/4), kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(int(num_latents/4)),
                nn.LeakyReLU(negative_slope=0.1),
                
                # If you added any other ConvTranspose layers, ensure the channel sizes match correctly for those as well.
                
                nn.Conv1d(int(num_latents/4), out_channels=1, kernel_size=1, stride=1, padding=0)
            ])
        
        self.decode = nn.Sequential(*layers)
        print(self.decode)
        
    def forward(self, latent: torch.Tensor):
        return self.decode(latent)
    
    
        
    

Finally, putting it all together in the encoder-decoder model


In [78]:
from typing import *

class TabPretrainer(nn.Module): 
    """
    Encoder-decoder model for pre-training tabular data.
    # TODO - refactor abstract base class for initialisations 
    """
    def __init__(self,
                 sample: torch.Tensor,
                 latent_shape: List[int],
                 input_axis: int = 1,
                 attn_dropout: float = 0.1,
                 num_heads: int = 4,
                 num_freq_bands: int=8,
                 ):
        super().__init__()
        self.input_channels = sample.shape[-1]
        self.input_axis = input_axis
        self.num_latents, self.latent_dim = latent_shape  # (n x d) [256, 32]
        # self.latent_dim, self.num_latents = latent_shape
        self.attn_dropout = attn_dropout
        self.num_heads = num_heads
        self.num_freq_bands = num_freq_bands
        
        
        # randomly initialise latent
        self.latent = nn.Parameter(torch.randn(self.num_latents, self.latent_dim))
        
        # encoder
        self.encoder = AttentionEncoder(
            input_channels=self.input_channels, 
            latent=self.latent,
            input_axis=self.input_axis, 
            attn_dropout=attn_dropout, 
            num_heads=num_heads, 
            num_freq_bands=num_freq_bands
        )
        
        # decoder
        self.decoder = TabularDecoder(
            latent_dim=self.latent_dim,
            num_latents=self.num_latents,
            output_dim=self.input_channels,
            method="dense" # using simple encoder to force good representation
        )
        
    def forward(self, x: torch.Tensor):
        
        # expand latent to batch size
        if len(self.latent.shape) == 2:
            # get batch dim
            b = x.shape[0]
            self.latent = nn.Parameter(einops.repeat(self.latent, "n d -> b n d", b=b))
        
        # encode
        # works much better with skip connections
        print(x.shape)
        self.latent.data = self.encoder(query=x, latent=self.latent).data
        # print(self.latent.shape)
        # decode, reconstructed x
        rec_x = self.decoder(self.latent)
        return rec_x
    
    def get_latent(self):
        return self.latent

Next, we need to think about tabular loss functions. Here, we can explore both reconstruction losses and contrastive losses. 

In [48]:
class TabularLoss(nn.Module):
    """
    Reconstruction loss functions for tabular data. We use two types which are commonly used with continuous data: 
    - Mean squared error
    - Constrastive loss, measured as cosine distance between the original and reconstructed data
    We seek to minimise both objectives.
    """
    def __init__(self,
                 method: str = "mse",
                 reduction: str = "mean",
                 ):
        super().__init__()
        assert method in ["mse", "contrastive"], "Loss type not recognised"
        self.loss_type = method
        self.reduction = reduction
        
        if method == "mse":
            self.loss = nn.MSELoss(reduction=reduction)
        elif method == "contrastive":
            self.loss = nn.CosineEmbeddingLoss(reduction=reduction)
            
    def __call__(self, **kwargs):
        return self.loss(**kwargs)
    

Finally, we write a pre-training loop that we can use for pre-training across cancer sites. 

In [29]:
# get overlap between omic columns
col1 = blca.omic_df.columns
col2 = brca.omic_df.columns
print(len(col1), len(col2), len(set(col1).intersection(col2)))

2191 2922 1758


In [30]:
blca.features

# types of features
# continuous: *_rnaseq, age
# categorical: *_mut, *_cnv 

Unnamed: 0,age,is_female,AAK1_rnaseq,AATK_rnaseq,ABCB1_rnaseq,ABCG2_rnaseq,ABI1_rnaseq,ABL1_rnaseq,ABL2_rnaseq,ACE_rnaseq,ACKR1_rnaseq,ACKR3_rnaseq,ACSL3_rnaseq,ACSL6_rnaseq,ACVR1B_rnaseq,ACVR1C_rnaseq,ACVR1_rnaseq,ACVR2A_rnaseq,ACVR2B_rnaseq,ACVRL1_rnaseq,ADAM10_rnaseq,ADAM17_rnaseq,ADCK1_rnaseq,ADCK2_rnaseq,ADCK5_rnaseq,...,UTRN_mut,VCAN_mut,VPS13B_mut,VPS13C_mut,VPS13D_mut,WDFY3_mut,WNK1_mut,XIRP2_mut,XIST_mut,ZDBF2_mut,ZFHX3_mut,ZFHX4_mut,ZFP36L1_mut,ZFYVE26_mut,ZFYVE9_mut,ZNF236_mut,ZNF292_mut,ZNF423_mut,ZNF521_mut,ZNF536_mut,ZNF626_mut,ZNF804A_mut,ZNF91_mut,ZZEF1_mut,RAS_mut
0,63,0,-0.6734,-0.4660,0.8401,-0.2222,2.2318,-0.8171,0.8051,-0.1250,-0.2976,1.2538,-0.3237,-0.1429,0.5258,-0.0748,-0.2048,-0.3004,0.2998,-0.6414,1.2149,1.1643,0.3720,-0.2883,-0.1974,...,1,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1
1,66,0,2.4277,-0.3853,0.1104,-0.2183,-0.0952,-0.6255,0.0970,-0.4911,-0.1779,-0.4134,-0.1501,-0.1576,-0.3597,0.4555,-1.0758,0.3252,1.7109,-0.5763,2.5860,1.5608,-0.6966,0.1801,-0.3164,...,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
2,66,0,2.4277,-0.3853,0.1104,-0.2183,-0.0952,-0.6255,0.0970,-0.4911,-0.1779,-0.4134,-0.1501,-0.1576,-0.3597,0.4555,-1.0758,0.3252,1.7109,-0.5763,2.5860,1.5608,-0.6966,0.1801,-0.3164,...,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
3,69,0,1.1340,-0.4110,0.1572,0.0752,0.0566,-1.3448,-0.3876,1.0335,-0.3683,-0.3736,-0.3294,-0.1807,0.8215,-0.7729,-0.7901,-0.9142,-0.2716,-0.2526,1.2477,0.8202,-0.1294,0.7846,-0.2564,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,59,1,-0.5311,0.1418,-0.0998,-0.2493,-0.6956,-0.3696,-0.1672,-0.7257,-0.3450,-0.1465,0.2727,0.3077,1.3352,2.3315,0.2386,1.6382,-0.2124,-0.7441,0.9661,0.6493,-1.2289,-0.0261,-0.1046,...,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
432,71,0,2.8284,0.9219,-0.4711,-0.1561,-0.7569,-0.0186,2.2843,1.9383,-0.3507,-0.4448,-0.4321,-0.1794,-1.0555,-0.6594,1.4145,-1.0607,-0.2349,0.9236,-0.8023,1.6165,-1.0423,-1.2719,-0.7105,...,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
433,61,1,0.9422,0.2662,-0.5276,-0.3078,-0.5164,-0.3653,1.6644,0.4936,-0.3722,-0.5500,-0.3539,-0.1324,-1.1019,-0.7735,0.0458,-1.1459,-0.6342,0.4675,-1.0416,-0.9597,-1.3496,0.4849,1.6007,...,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1
434,60,1,-0.3000,-0.5301,-0.5559,0.4720,0.0597,-0.1817,3.6724,0.1842,-0.3892,-0.6284,0.2315,-0.1468,-1.1972,-0.5964,-0.2927,-0.9101,-0.6431,-0.3869,-0.5002,-0.0913,-0.6892,-0.6854,1.1884,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
435,62,1,3.2208,-0.2592,-0.8130,-0.1423,-1.2421,-1.4423,-0.6631,-0.9650,-0.3868,1.9972,0.0008,-0.1583,-1.4766,-0.7790,-1.2992,-1.0639,0.5145,-0.8358,-0.1595,-0.0378,-0.3054,0.7819,1.9629,...,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1


In [157]:
from tqdm import tqdm

torch.set_printoptions(sci_mode=False)


def pretrain_loop(
        data: TCGADataset,
        batch_size: int, 
    ):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    loader = DataLoader(
        data, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=multiprocessing.cpu_count()-1
    )
    [omic_sample], _, _, _ = next(iter(loader))   
    
    
    model = TabPretrainer(
        sample = omic_sample, 
        input_axis=1, 
        latent_shape=[256, 32], # (n_l x d_l)
        attn_dropout=0.1, 
        num_heads=8,
        num_freq_bands=8
    )
    model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    loss_method = "mse"
    loss_fn = TabularLoss(method=loss_method)
    
    for epoch in tqdm(range(10)):
        for idx, batch in enumerate(loader):
            [omic], censorship, event_time, y_disc = batch
            omic = omic.to(device)
            rec_omic = model(omic)
            # print(rec_omic.shape)
            # print(omic.shape)
            if loss_method == "contrastive":
                # need to pass in larges for contrastive loss
                # using torch.ones to ensure that omic and rec_omic are learned as similar representations
                # note that this is a slight repurposing of the contrastive loss function
                # with this, the loss is just 1-cos(omic, rec_omic)
                loss = loss_fn(input1=omic, input2=rec_omic, target=torch.ones(omic.shape[0]))
            elif loss_method == "mse": 
                loss = loss_fn(input=omic, target=rec_omic)
            # loss = loss_fn(omic, rec_omic)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # print every 10th batch
            if idx % 100 == 0:
                pass
                # print(loss)
                # print(omic)
                # print(rec_omic)
        # print epoch-level stats
        print(f"Epoch {epoch+1} loss: {loss}")
        # final reconstruction
        # print error vector
        # print((omic - rec_omic).abs())
        # print(omic)
        # print(rec_omic)
    return model
        
            
    
tab_model = pretrain_loop( data=blca, batch_size=1)
tab_latent = tab_model.get_latent()
    

32 2183
Attention(
  (query_proj): Linear(in_features=2183, out_features=70048, bias=True)
  (key_proj): Linear(in_features=32, out_features=32, bias=True)
  (value_proj): Linear(in_features=32, out_features=32, bias=True)
)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=8192, out_features=1024, bias=True)
  (2): LeakyReLU(negative_slope=0.01)
  (3): InstanceNorm1d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (4): Dropout(p=0.5, inplace=False)
  (5): Linear(in_features=1024, out_features=512, bias=True)
  (6): LeakyReLU(negative_slope=0.01)
  (7): InstanceNorm1d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (8): Dropout(p=0.5, inplace=False)
  (9): Linear(in_features=512, out_features=256, bias=True)
  (10): LeakyReLU(negative_slope=0.01)
  (11): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (12): Dropout(p=0.5, inplace=False)
  (13): Linear(in_features=256, 

  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])


  return F.mse_loss(input, target, reduction=self.reduction)


torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])


 10%|█         | 1/10 [00:02<00:26,  2.89s/it]

Epoch 1 loss: 5.1362385749816895
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 20%|██        | 2/10 [00:05<00:23,  2.95s/it]

Epoch 2 loss: 0.6009393930435181
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 30%|███       | 3/10 [00:08<00:20,  2.99s/it]

Epoch 3 loss: 0.7063509225845337
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 40%|████      | 4/10 [00:11<00:17,  2.96s/it]

Epoch 4 loss: 0.6250912547111511
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 50%|█████     | 5/10 [00:14<00:14,  2.98s/it]

Epoch 5 loss: 2.1324405670166016
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 60%|██████    | 6/10 [00:17<00:11,  2.97s/it]

Epoch 6 loss: 3.038796901702881
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 

 70%|███████   | 7/10 [00:20<00:08,  2.98s/it]

Epoch 7 loss: 1.2081642150878906
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 80%|████████  | 8/10 [00:23<00:05,  2.94s/it]

Epoch 8 loss: 0.6560969948768616
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

 90%|█████████ | 9/10 [00:26<00:02,  2.96s/it]

Epoch 9 loss: 1.0824613571166992
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1, 2183])
torch.Size([1, 1,

100%|██████████| 10/10 [00:29<00:00,  2.96s/it]

Epoch 10 loss: 0.9682132005691528





In [13]:
# get encoder attention weights for a test sample
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = tab_model.encoder.to(device)
encoder.eval()
sample = next(iter(blca_loader))
[omic], _, _, _ = sample
omic = omic.to(device)
# initialise latent
latent = torch.randn(256, 32).to(device)
# omic.to(device)
# latent.to(device)
# watch out for leakage here
print(tab_latent)
encoder(latent=tab_latent, context=omic)
# get attn_weights from encoder
encoder.layers[0].fn.attn_weights.shape
# encoder.layers

Parameter containing:
tensor([[[     0.5905,      0.0785,     -0.0000,  ...,      0.3847,
              -0.0079,     -0.0022],
         [     0.7565,      0.3375,      0.3621,  ...,      0.4624,
              -0.0075,     -0.0020],
         [     0.7565,      0.3375,      0.3621,  ...,      0.4624,
              -0.0075,     -0.0020],
         ...,
         [     0.5905,      0.0785,     -0.0000,  ...,      0.3847,
              -0.0079,     -0.0022],
         [     0.7565,      0.3375,      0.3621,  ...,      0.4624,
              -0.0075,     -0.0020],
         [     0.7565,      0.3375,      0.3621,  ...,      0.4624,
              -0.0075,     -0.0020]]], device='cuda:0', requires_grad=True)


torch.Size([4, 256, 1])