In [1]:
from models import *
from saveAndLoad import *
from functools import partial
import math

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class Dataset_test(Dataset):
    def __init__(self):
        tlong = lambda x: torch.tensor(x, dtype=torch.long)
        tlong2d = lambda x: torch.tensor([x], dtype=torch.long)
        tfloat = lambda x: torch.tensor(x, dtype=torch.float32)
        tfloat2d = lambda x: torch.tensor([x], dtype=torch.float32)
        randemb = lambda: torch.rand(640)
        
        self.cancer = [tlong2d(1), tlong2d(3)]
        self.cancer_d = [tlong2d(3), tlong2d(5)]
        self.sex = [tlong2d(0), tlong2d(2)]
        self.age = [tfloat2d(42), tfloat2d(79)]
        self.race = [tlong2d(0), tlong2d(1)]
        self.time = [tfloat2d(float('nan')), tfloat2d(491)]
        self.event = [tlong2d(0), tlong2d(0)]
        self.gene_ids = [
            torch.stack([tlong(1), tlong(2), tlong(3)]),
            torch.stack([tlong(4), tlong(5)])
        ]
        self.gene_emb = [
            torch.stack([randemb(), randemb(), randemb()]),
            torch.stack([randemb(),randemb()])
            ]
        self.maf = [
            torch.stack([tfloat(.7), tfloat(.8), tfloat(.4)]),
            torch.stack([tfloat(.05),tfloat(float('nan'))])
            ]
        self.focal_cna_ids = [
            torch.stack([tlong(1)]),
            torch.stack([tlong(1), tlong(4), tlong(7)])
        ]
        self.focal_cna = [
            torch.stack([tlong(0)]),
            torch.stack([tlong(0), tlong(1), tlong(0)])
        ]
        # self.broad_cna = [                                #broad is binary
        #     torch.tensor([], dtype=torch.long),
        #     torch.stack([tlong(0), tlong(2), tlong(3)])
        # ]

        #--seg data--#
        self.seg_ids = [
            torch.stack([tlong(1), tlong(2)]),
            torch.stack([tlong(4), tlong(5), tlong(14), tlong(23)])
        ]
        self.seg_start = [                              #broad is emb                   
            torch.stack([tfloat(12345*3e-9), tfloat(123456*3e-9)]),
            torch.stack([tfloat(54321*3e-9), tfloat(999321*3e-9), tfloat(123456789*3e-9), tfloat(123456789*3e-9)])
        ]
        self.seg_end = [                                #broad is emb
            torch.stack([tfloat(12459*3e-9), tfloat(125555*3e-9)]),
            torch.stack([tfloat(64321*3e-9), tfloat(1111321*3e-9), tfloat(123467890*3e-9) ,tfloat(123467890*3e-9)])
        ]
        self.seg_mean = [                               #broad is emb
            torch.stack([tfloat(0.5), tfloat(-0.7)]),
            torch.stack([tfloat(0.1), tfloat(-0.2), tfloat(0.3), tfloat(0.3)])
        ]

    def __len__(self):
        return len(self.cancer)

    def __getitem__(self, idx):
        cancer = self.cancer[idx]
        cancer_d = self.cancer_d[idx]
        sex = self.sex[idx]
        age = self.age[idx]
        race = self.race[idx]
        time = self.time[idx]
        event = self.event[idx]
        gene_id = self.gene_ids[idx]
        gene_emb = self.gene_emb[idx]
        maf = self.maf[idx]
        focal_cna_id = self.focal_cna_ids[idx]
        focal_cna = self.focal_cna[idx]
        # broad_cna = self.broad_cna[idx] #broad is binary
        seg_id = self.seg_ids[idx]
        seg_start = self.seg_start[idx]
        seg_end = self.seg_end[idx]
        seg_mean = self.seg_mean[idx]
        # return cancer, sex, age, time, event, gene_emb, maf, focal_cna, broad_cna                                 #broad is binary
        return cancer, cancer_d, sex, age, race, time, event, gene_id, gene_emb, maf, focal_cna_id, focal_cna, seg_id, seg_start, seg_end, seg_mean  #broad is emb

def collate_test(batch, config):
    cancer =    torch.stack([item[0] for item in batch])
    cancer_d =  torch.stack([item[1] for item in batch])
    sex =       torch.stack([item[2] for item in batch])
    age =       torch.stack([item[3] for item in batch])
    race =      torch.stack([item[4] for item in batch])
    time =      torch.stack([item[5] for item in batch])
    event =     torch.stack([item[6] for item in batch])

    def pad_and_mask(batch_, i, pad_val, mask = True):
        unpadded_list = [item[i] for item in batch_]
        padded_list = pad_sequence(unpadded_list, batch_first=True, padding_value=pad_val)
        if mask:
            lengths = [j.size(0) for j in unpadded_list]
            max_length = padded_list.size(1)
            # print('\n-----------------')
            # print('max_length',max_length)
            # print('lengths',lengths)
            # print('torch.arange(max_length).expand(len(lengths), max_length)\n',torch.arange(max_length).expand(len(lengths), max_length))
            mask = torch.arange(max_length).expand(len(lengths), max_length) >= torch.tensor(lengths).unsqueeze(1)
            # print('torch.arange(max_length).expand(len(lengths), max_length) >= torch.tensor(lengths).unsqueeze(1)\n', mask)
            # print('-----------------\n')
            return padded_list, mask
        return padded_list

    gene_id =                       pad_and_mask(batch, 7, config.gene_id_vocab_size, mask = False)
    gene_emb =                      pad_and_mask(batch, 8, 0, mask = False)
    maf =                           pad_and_mask(batch, 9, 0, mask = False)
    focal_cna_id =                  pad_and_mask(batch, 10, config.gene_id_vocab_size, mask = False)
    focal_cna =                     pad_and_mask(batch, 11, config.focal_cna_vocab_size, mask = False)
    # broad_cna, broad_cna_pad_mask = pad_and_mask(batch, 11, config.broad_cna_vocab_size) #broad is binary
    
    seg_id =                        pad_and_mask(batch, 12, config.seg_id_vocab_size, mask = False)
    seg_start =                     pad_and_mask(batch, 13, 0, mask = False)                      #broad is emb
    seg_end =                       pad_and_mask(batch, 14, 0, mask = False)        #broad is emb
    seg_mean =                      pad_and_mask(batch, 15, 0, mask = False)        #broad is emb

    # return (cancer, sex, age, time, event, gene_emb, maf, focal_cna, broad_cna, pad_mask) #broad is binary
    return (cancer, cancer_d, sex, age, race, time, event, gene_id, gene_emb, maf, focal_cna_id, focal_cna, seg_id, seg_start, seg_end, seg_mean) #broad is emb

In [5]:
## add mutational burden
## train model to predict effect of PARP inhibition
## IMPLEMENT PERCENTILE RBF

In [6]:
sex_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_SEX_somatt_data_df.pkl')
race_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_RACE_somatt_data_df.pkl')
data_df = pd.read_pickle('../labeled_data/somatt_data_df.pkl')

class Config_Somatt:
    n_layer: int = 3
    emb_dim: int = 640 #1152 esmC #1536 esm3 #640 esm2
    input_dim: int = 640
    dropout: float = 0.0
    bias: bool = False
    gene_id_vocab_size : int = 1433
    cancer_type_vocab_size: int = 10
    cancer_type_detailed_vocab_size: int = 33
    pooling : str = 'mean'
    norm_fn: nn.Module = nn.LayerNorm
    position_embedding: bool = False
    sex_label_map: dict = sex_label_mapping
    sex_vocab_size: int =  len(sex_label_mapping)
    race_label_map: dict = race_label_mapping
    race_vocab_size: int = len(race_label_mapping)
    num_heads: int = 1
    lin_proj: nn.Module = LinProj
    rbf_params = (0,1,16)
    n_clin_vars: int = 6
    maf_emb_dim: int = 16
    seg_id_vocab_size: int = 46
    broad_cna_vocab_size: int = 2   
    focal_cna_vocab_size: int = 2   
    event_vocab_size: int = 2
    embed_surv: nn.Module = EmbedSurv_rbf_2heads

config = Config_Somatt()
model = Somatt(config)

ds = Dataset_test()

collate_test_with_config = partial(collate_test, config = config)
loader = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=collate_test_with_config)

for batch in loader:
    mask = create_mask(2,3,probs=[.5,.5,.5])
    output = model(batch, mask = mask)
    print(output)

loading data from ../labeled_data/label_mappings/label_mapping_SEX_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_RACE_somatt_data_df.pkl


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f20ab891a20>>
Traceback (most recent call last):
  File "/home/dandreas/.conda/envs/esm3/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


AttributeError: 'Config_Somatt' object has no attribute 'sex_nan_idx'

loading data from ../aa/tumors.pkl
loading data from ../data_processing/map_tumorBarcode_to_clinicalSampleIdx.pkl
loading data from ../aa/assays.pkl
loading data from ../labeled_data/label_mappings/label_mapping_SEX_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_RACE_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_ARM_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_CANCER_TYPE_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_CANCER_TYPE_DETAILED_somatt_data_df.pkl


In [4]:
import torch
import torch.nn as nn

# Assume these loss functions are defined:
classification_loss_fn = nn.CrossEntropyLoss()


# Example training loop snippet
def train_masked_loss(model, dataloader, optimizer, device):
    model.train()
    for batch in dataloader:
        cancer, cancer_d, sex, age, race, time, event, gene_id, gene_emb, maf, focal_cna_id, focal_cna, seg_id, seg_start, seg_end, seg_mean, seg_pad_mask = batch
        mask = create_mask (cancer.size(0), 3, probs=[.5,.5,.5])
        optimizer.zero_grad()
        
        outputs = model(batch, mask=mask)
        
        cancer_logits = outputs[0]          # shape: [B, num_cancer_classes]
        cancer_detailed_logits = outputs[1]   # shape: [B, num_cancer_detailed_classes]
        survival_logits = outputs[5]          # shape: [B, 1] (or [B])
        
        mask_cancer = mask[:, 0].squeeze()          # shape: [B]
        mask_cancer_detailed = mask[:, 1].squeeze()   # shape: [B]
        mask_survival = mask[:, 2].squeeze()          # shape: [B]
        
        # Compute loss only over the masked tokens:
        if mask_cancer.sum() > 0:
            loss_cancer = classification_loss_fn(
                cancer_logits[mask_cancer], 
                cancer[mask_cancer].long().squeeze()
            )
        else:
            loss_cancer = torch.tensor(0.0, device=device)
        
        if mask_cancer_detailed.sum() > 0:
            loss_cancer_detailed = classification_loss_fn(
                cancer_detailed_logits[mask_cancer_detailed], 
                cancer_d[mask_cancer_detailed].long().squeeze()
            )
        else:
            loss_cancer_detailed = torch.tensor(0.0, device=device)
        
        # For survival, stack time and event into a tensor of shape [B, 2]
        survival_targets = torch.cat([time, event], dim=1)
        if mask_survival.sum() > 0:
            masked_survival = survival_targets[mask_survival]
            masked_risk = survival_logits[mask_survival].squeeze()

            # Assume negative_log_partial_likelihood takes survival_targets and survival logits.
            loss_survival = negative_log_partial_likelihood(
                survival_targets[mask_survival],
                survival_logits[mask_survival].squeeze()
            )
        else:
            loss_survival = torch.tensor(0.0, device=device)
        
        # Combine losses (you can also weight them if needed)
        total_loss = loss_cancer + loss_cancer_detailed + loss_survival
        
        total_loss.backward()
        optimizer.step()
        
        print(f"Batch Loss: {total_loss.item():.4f}")

config = Config_Somatt()
model = Somatt(config)
collate_somatt_with_config = partial(collate_somatt, config = config)
ds = Dataset_Somatt(data_df, mut_embeddings, ref_embeddings, tumors, assays, device)
loader = DataLoader(ds, batch_size=100, shuffle=True, collate_fn=collate_somatt_with_config)

train_masked_loss(model, loader, torch.optim.Adam(model.parameters()), device)

164586 samples
147659 unique patients


ValueError: not enough values to unpack (expected 17, got 16)

In [1]:
from models import *
from somatt import *
from saveAndLoad import pickleLoad

tumors = pickleLoad('../aa/tumors.pkl')
map_tumorBarcode_to_clinicalSampleIdx = pickleLoad('../data_processing/map_tumorBarcode_to_clinicalSampleIdx.pkl')
mut_embeddings = np.load('../aa/canonical_mut_norm_embeddings_esm2.npy')
ref_embeddings = np.load('../aa/canonical_ref_embeddings_esm2.npy')
assays = pickleLoad('../aa/assays.pkl')
device = 'cuda:1'
sex_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_SEX_somatt_data_df.pkl')
race_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_RACE_somatt_data_df.pkl')
arm_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_ARM_somatt_data_df.pkl')
cancer_type_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_CANCER_TYPE_somatt_data_df.pkl')
cancer_type_detailed_label_mapping = pickleLoad('../labeled_data/label_mappings/label_mapping_CANCER_TYPE_DETAILED_somatt_data_df.pkl')
n_cancer_types = len(cancer_type_label_mapping)
n_cancer_type_d = len(cancer_type_detailed_label_mapping)
n_seg_id = len(arm_label_mapping)
sex_nan_idx = len(sex_label_mapping)-1
assert math.isnan(sex_label_mapping[sex_nan_idx])
race_nan_idx = len(race_label_mapping)-1
assert math.isnan(race_label_mapping[race_nan_idx])

class Config_Somatt:
    n_layer: int = 3
    emb_dim: int = 640 #1152 esmC #1536 esm3 #640 esm2
    input_dim: int = 640
    dropout: float = 0.1
    bias: bool = False
    gene_id_vocab_size : int = 1433
    cancer_type_vocab_size: int = n_cancer_types
    cancer_type_detailed_vocab_size: int = n_cancer_type_d
    norm_fn: nn.Module = nn.LayerNorm
    position_embedding: bool = False
    sex_nan_idx: int = sex_nan_idx
    sex_vocab_size: int =  len(sex_label_mapping)
    race_nan_idx: int = race_nan_idx
    race_vocab_size: int = len(race_label_mapping)
    num_heads: int = 1
    lin_proj: nn.Module = LinProj
    rbf_params = (0,1,16)
    n_clin_vars: int = 6
    maf_emb_dim: int = 16
    seg_id_vocab_size: int = n_seg_id
    pool_gene: bool = False
    pool_seg: bool = True
    broad_cna_vocab_size: int = 2   
    focal_cna_vocab_size: int = 2   
    event_vocab_size: int = 2
    embed_surv: nn.Module = EmbedSurv_rbf_2heads

# Example usage:
config = Config_Somatt()  # Your defined config class for Somatt.
model = Somatt(config)
collate_somatt_with_config = partial(collate_somatt, config=config)

data_df = pd.read_pickle('../labeled_data/somatt_data_df.pkl')
label_counts = data_df['CANCER_TYPE'].value_counts().to_dict()
min_cancer_type = .01
filter_rarity = lambda x: label_counts[x]>=min_cancer_type
min_cancer_type = int(len(data_df)*min_cancer_type)
data_df = data_df[data_df['CANCER_TYPE'].apply(filter_rarity)]
print(len(data_df))
data_df = data_df.sample(n=10000, random_state=42)

ds = Dataset_Somatt(data_df, mut_embeddings, ref_embeddings, tumors, assays, device)

train_somatt(Somatt, Config_Somatt, ds, data_df, batch_size = 100, saveName=None, test_size=0.2, num_epochs=15, lr=1e-4, device='cuda:1', collate_fn = collate_somatt_with_config)


  from .autonotebook import tqdm as notebook_tqdm


loading data from ../aa/tumors.pkl
loading data from ../data_processing/map_tumorBarcode_to_clinicalSampleIdx.pkl
loading data from ../aa/assays.pkl
loading data from ../labeled_data/label_mappings/label_mapping_SEX_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_RACE_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_ARM_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_CANCER_TYPE_somatt_data_df.pkl
loading data from ../labeled_data/label_mappings/label_mapping_CANCER_TYPE_DETAILED_somatt_data_df.pkl
143530
10000 samples
9907 unique patients

FOLD 1


TRAINING: 100%|██████████| 81/81 [01:01<00:00,  1.32it/s, Epoch=1/15, Loss: 2.2230] 


Epoch 1/15 - Train Loss: 9.8352


TESTING: 100%|██████████| 20/20 [00:07<00:00,  2.74it/s]


Epoch 1 - Test Accuracy: 18.31%, Detailed: 13.76%, Survival C-Index: 0.8313


TRAINING: 100%|██████████| 81/81 [01:00<00:00,  1.34it/s, Epoch=2/15, Loss: 1.3604]


Epoch 2/15 - Train Loss: 8.9377


TESTING: 100%|██████████| 20/20 [00:07<00:00,  2.84it/s]


Epoch 2 - Test Accuracy: 26.26%, Detailed: 18.61%, Survival C-Index: 0.8309


TRAINING: 100%|██████████| 81/81 [00:58<00:00,  1.37it/s, Epoch=3/15, Loss: 3.1847] 


Epoch 3/15 - Train Loss: 8.4144


TESTING: 100%|██████████| 20/20 [00:06<00:00,  2.87it/s]


Epoch 3 - Test Accuracy: 32.37%, Detailed: 19.56%, Survival C-Index: 0.8319


TRAINING:  10%|▉         | 8/81 [00:05<00:52,  1.40it/s, Epoch=4/15, Loss: 9.0424]


KeyboardInterrupt: 

In [2]:
long_tensor = lambda x: torch.tensor(x, dtype=torch.long)
long_tensor([])


tensor([], dtype=torch.int64)