# 🧪 Data preparation

In [1]:
!pip install pyfaidx

Collecting pyfaidx
  Downloading pyfaidx-0.8.1.4-py3-none-any.whl.metadata (25 kB)
Downloading pyfaidx-0.8.1.4-py3-none-any.whl (28 kB)
Installing collected packages: pyfaidx
Successfully installed pyfaidx-0.8.1.4


In [2]:
import torch
import requests
import itertools
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import plotly.express as px
import matplotlib.pyplot as plt

from pyfaidx import Fasta
from google.cloud import bigquery
from collections import defaultdict
from typing import List, Tuple, Dict
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from google.api_core.exceptions import GoogleAPICallError
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, classification_report

In [3]:
client = bigquery.Client()

Using Kaggle's public dataset BigQuery integration.


In [4]:
def get_vogelstein_genes() -> List[str]:
    url = "https://www.oncokb.org/api/v1/utils/cancerGeneList"
    try:
        response = requests.get(url)
        response.raise_for_status()
        cancer_genes = response.json()
        
        # Filter genes from Gogelstein et al. (2013)
        vogelstein_genes = [
            gene['hugoSymbol'] for gene in cancer_genes
            if gene.get('vogelstein') and gene['vogelstein'] is True
        ]
        return vogelstein_genes
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to fetch OncoKB data: {e}")


vogelstein_genes = get_vogelstein_genes()
print(f"Found {len(vogelstein_genes)} Vogelstein et al. (2013) genes.")

Found 125 Vogelstein et al. (2013) genes.


In [5]:
def get_somatic_mutations(genes: List[str]) -> pd.DataFrame:
    # Gene list string for SQL
    gene_list = ', '.join([f"'{gene}'" for gene in genes])
    
    query = f"""
            SELECT
                case_barcode AS Barcode,
                Hugo_Symbol AS Gene,
                Variant_Type AS Variant,
                Tumor_Seq_Allele1 AS Reference,
                Tumor_Seq_Allele2 AS Mutation,
                Start_Position AS Start,
                End_Position,
                Chromosome,
                primary_site    
            FROM
                `isb-cgc-bq.TCGA.masked_somatic_mutation_hg38_gdc_current`
            WHERE
                Hugo_Symbol IN ({gene_list})
                AND Variant_Type IN ('SNP', 'DEL', 'INS');
    """
    
    try:
        return client.query(query).to_dataframe()      
    except GoogleAPICallError as e:
        raise RuntimeError(f"BigQuery error: {e}")


mutations_df = get_somatic_mutations(vogelstein_genes)
print(f"Fetched {mutations_df.shape[0]} somatic mutations")



Fetched 48912 somatic mutations


In [6]:
def save_parquet(df: pd.DataFrame,
                 path: str = "file.parquet") -> None:
    df.to_parquet(path, engine='pyarrow')


save_parquet(mutations_df, "tcga_somatic_mutations.parquet")

In [7]:
def get_rna_expression(genes: List[str]) -> pd.DataFrame:
    gene_list = ', '.join([f"'{gene}'" for gene in genes])
    query = f"""
        SELECT
            case_barcode AS Barcode,
            gene_name AS Gene,
            tpm_unstranded AS Tpm,
            fpkm_unstranded AS Fpkm,
            fpkm_uq_unstranded AS QFpkm,
            primary_site
        FROM
            `isb-cgc-bq.TCGA.RNAseq_hg38_gdc_current`
        WHERE
            gene_name IN ({gene_list})
    """
    try:
        job = client.query(query)
        df = job.to_dataframe()

        # Remove duplicates
        df = df.drop_duplicates(subset=['Barcode', 'Gene', 'Tpm', 'Fpkm'])
        return df
    except GoogleAPICallError as e:
        raise RuntimeError(f"BigQuery error: {e}")


expression_df = get_rna_expression(vogelstein_genes)
print(f"Fetched {expression_df.shape[0]} RNA-seq records")

Fetched 1445936 RNA-seq records


In [8]:
save_parquet(expression_df, 'tcga_expression.parquet')

In [9]:
# Read parquet files
df_expressions = pd.read_parquet('/kaggle/working/tcga_expression.parquet')
df_mutations = pd.read_parquet('/kaggle/working/tcga_somatic_mutations.parquet')

In [10]:
# Class imbalance check
display(df_mutations['primary_site'].value_counts())

primary_site
Corpus uteri                                                              10424
Bronchus and lung                                                          5875
Colon                                                                      4881
Skin                                                                       3755
Brain                                                                      3314
Stomach                                                                    3178
Bladder                                                                    2778
Breast                                                                     2737
Kidney                                                                     1373
Cervix uteri                                                               1361
Liver and intrahepatic bile ducts                                          1045
Ovary                                                                       993
Rectum                     

In [11]:
def plot_histogram(df: pd.DataFrame) -> None:
    primary_site_counts = df['primary_site'].value_counts().reset_index()
    primary_site_counts.columns = ['primary_site', 'count']
    fig = px.bar(primary_site_counts, 
             x='primary_site', 
             y='count', 
             log_y=True,
             title='Distribution of Primary Sites (Log Scale)')
    fig.update_layout(xaxis_tickangle=-90)
    fig.show()


plot_histogram(df_mutations)
plot_histogram(df_expressions)

# 🧬 Feature engineering

In [12]:
# Download Human Genome
!wget http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
!gunzip hg38.fa.gz

--2025-06-23 02:10:21--  http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
Resolving hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)... 128.114.119.163
Connecting to hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)|128.114.119.163|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 983659424 (938M) [application/x-gzip]
Saving to: ‘hg38.fa.gz’


2025-06-23 02:10:49 (34.5 MB/s) - ‘hg38.fa.gz’ saved [983659424/983659424]



In [13]:
genome = Fasta('/kaggle/working/hg38.fa')

In [14]:
def reverse_complement(seq: str) -> str:
    """ Reverse complement of DNA sequence.
    """
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'}
    return ''.join(complement.get(base, base) for base in seq[::-1])

def complement_base(base: str) -> str:
    """ Complement single base.
    """
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'}
    return complement.get(base, base)

def normalize_mutation(triplet: str,
                       ref: str,
                       alt: str) -> Tuple[str, str, str]:
    """ Normalize to pyrimidine context.
    """
    if ref in ['A', 'G']:  # Purine
        triplet = reverse_complement(triplet)
        ref = complement_base(ref)
        alt = complement_base(alt)
    return triplet, ref, alt

def get_trinucleotide_context(chrom: str,
                              pos: int,
                              ref: str,
                              genome: Fasta) -> Tuple[str]:
    """ Get trinucleotide context with robust error handling.
    """
    try:
        chrom_name = f"chr{chrom}" if chrom != "MT" else "chrM"
        if chrom_name not in genome:
            return None, None, None
            
        # Get 3bp context
        start_idx = pos - 2
        end_idx = pos + 1
        triplet = genome[chrom_name][start_idx:end_idx].seq.upper()
        
        if len(triplet) != 3:
            return None, None, None
            
        return triplet, ref, triplet[1]
    except:
        return None, None, None

def create_genomic_bins(chrom_sizes: Dict[str, int],
                        bin_size: int = 1000000) -> Dict[str, int]:
    """ Create genomic bins with chromosome handling.
    """
    bins = []
    for chrom, size in chrom_sizes.items():
        # Calculate number of bins
        num_bins = (size // bin_size) + 1
        for i in range(1, num_bins + 1):
            bins.append(f"{chrom}_{i}")
    
    return {bin_id: idx for idx, bin_id in enumerate(bins)}

def predefine_motifs() -> Dict[str, int]:
    """ Generate motif mapping including:
      - 96 SNP trinucleotide contexts (pyrimidine-centered)
      - 64 DEL/INS contexts (flanking base, type, frame shift, flanking base)
    Returns: Dict[motif_str -> index].
    """
    bases = ['A', 'C', 'G', 'T']
    motifs: List[str] = []
    # 96 SNP contexts
    for left in bases:
        for ref in ['C', 'T']:
            for right in bases:
                for alt in bases:
                    if ref != alt:
                        motifs.append(f"{left}{ref}{right}_{ref}>{alt}")

    # 64 DEL/INS contexts
    types = ['DEL', 'INS']
    frames = ['NF', 'FS']  # NF: no frameshift, FS: frameshift
    for left, right, t, frame in itertools.product(bases, bases, types, frames):
        motifs.append(f"{left}_{t}_{frame}_{right}")

    return {motif: idx for idx, motif in enumerate(motifs)}

def get_genomic_base(chrom: str,
                     position: int,
                     genome: Fasta) -> str:
    """ Get single genomic base with error handling.
    """
    try:
        chrom_name = f"chr{chrom}" if chrom != "MT" else "chrM"
        if chrom_name not in genome:
            return "N"
            
        if position < 0 or position >= len(genome[chrom_name]):
            return "N"
            
        return genome[chrom_name][position:position+1].seq.upper()
    except:
        return "N"

def normalize_chrom_name(chrom: str) -> str:
    """ Normalize chromosome names to standard format.
    """
    chrom = str(chrom).upper().replace("CHR", "")
    if chrom == "23": return "X"
    if chrom == "24": return "Y"
    if chrom in ["M", "MT"]: return "MT"
    return chrom

In [15]:
def process_mutations(
    mutations_df: pd.DataFrame,
    genome: Fasta,
    genes: List[str]
    ) -> Tuple[pd.DataFrame, Dict[str, int], Dict[str, int], int]:
    """ Process mutations including SNP, DEL, and INS.
    """
    # Predefine motifs and chromosomes
    motif_map = predefine_motifs()
    chrom_sizes = {}
    chrom_mapping = {}
    
    for chrom in mutations_df['Chromosome'].unique().astype(str):
        chrom_clean = normalize_chrom_name(chrom)
        chrom_mapping[chrom] = chrom_clean  # Store mapping
        
        # Skip if already processed
        if chrom_clean in chrom_sizes:
            continue

        variant = f"chr{chrom_clean}"
        if variant in genome:
            chrom_sizes[chrom_clean] = len(genome[variant])
            print(f"Found {variant} with size {chrom_sizes[chrom_clean]}")
        else:
            print(f"Chromosome {chrom} not found in reference genome")
    
    # Bin mapping
    bin_mapping = create_genomic_bins(chrom_sizes)
    total_bins = len(bin_mapping)
    print(f"Created {total_bins} genomic bins")

    mutations_df['genomic_bin'] = None
    mutations_df['motif'] = None
    
    # Process each mutation
    for i, row in mutations_df.iterrows():
        var = row['Variant']
        orig_chrom = str(row['Chromosome'])
        chrom = chrom_mapping.get(orig_chrom, normalize_chrom_name(orig_chrom))
        
        try:
            if var == 'SNP':
                pos = int(row['Start'])
                ref = row['Reference']
                alt = row['Mutation']
                
                # Validate SNP
                if len(ref) != 1 or len(alt) != 1:
                    continue
                
                # Get trinucleotide context
                triplet, context_ref, center_base = get_trinucleotide_context(chrom,
                                                                              pos,
                                                                              ref,
                                                                              genome
                                                                            )

                if triplet and center_base == ref:
                    norm_triplet, norm_ref, norm_alt = normalize_mutation(triplet, ref, alt)
                    key = f"{norm_triplet}_{norm_ref}>{norm_alt}"
                else:
                    key = None

            elif var in ('DEL', 'INS'):
                start_pos = int(row['Start'])
                ref = row['Reference']
                alt = row['Mutation']
                
                if var == 'DEL':
                    if 'End_Position' in row and not pd.isna(row['End_Position']):
                        end_pos = int(row['End_Position'])
                    else:
                        end_pos = start_pos + len(ref) - 1
                    length = end_pos - start_pos + 1
                else:
                    end_pos = start_pos
                    length = len(alt)

                left_base = get_genomic_base(chrom, start_pos - 1, genome)  # Base before mutation
                right_base = get_genomic_base(chrom, end_pos + 1, genome)    # Base after mutation
                frame = 'FS' if (length % 3) != 0 else 'NF'
                key = f"{left_base}_{var}_{frame}_{right_base}"
                
            else:
                key = None
            
            # Assign genomic bin
            try:
                start_bin_idx = (int(row['Start'] - 1) // 1_000_000) + 1
                end_bin_idx = (int(row['End_Position'] - 1) // 1_000_000) + 1
                list_bins = []
                for idx in range(start_bin_idx, end_bin_idx + 1):
                    list_bins.append(bin_mapping.get(f"{chrom}_{idx}"))
                
                mutations_df.at[i, 'genomic_bin'] = list_bins
            except:
                mutations_df.at[i, 'genomic_bin'] = None
            
            # Assign motif
            mutations_df.at[i, 'motif'] = key

        except Exception as e:
            print(f"Error processing mutation at index {i}: {e}")
            mutations_df.at[i, 'genomic_bin'] = None
            mutations_df.at[i, 'motif'] = None

    return mutations_df, bin_mapping, motif_map, total_bins
   
def process_expression(expression_df: pd.DataFrame,
                       genes: List[str]) -> pd.DataFrame:
    """ Process RNA expression with log1p normalization.
    """
    all_samples = expression_df['Barcode'].unique()
    expression_matrix = pd.DataFrame(
        np.zeros((len(all_samples), len(genes))),
        index=all_samples,
        columns=genes
    )

    for _, row in expression_df.iterrows():
        sample = row['Barcode']
        gene = row['Gene']
        value = row['QFpkm']
        if gene in genes and sample in expression_matrix.index:
            expression_matrix.at[sample, gene] = value

    # log1p normalization
    expression_matrix = np.log1p(expression_matrix + 1e-6)
    
    return expression_matrix

In [16]:
mutations_processed, bin_mapping, motif_to_index, total_bins = process_mutations(
                                                                df_mutations,
                                                                genome,
                                                                vogelstein_genes
                                                            )
    
expression_matrix = process_expression(df_expressions, vogelstein_genes)

Found chr11 with size 135086622
Found chr15 with size 101991189
Found chrX with size 156040895
Found chr16 with size 90338345
Found chr19 with size 58617616
Found chr12 with size 133275309
Found chr1 with size 248956422
Found chr6 with size 170805979
Found chr22 with size 50818468
Found chr7 with size 159345973
Found chr18 with size 80373285
Found chr3 with size 198295559
Found chr14 with size 107043718
Found chr9 with size 138394717
Found chr17 with size 83257441
Found chr2 with size 242193529
Found chr21 with size 46709983
Found chr5 with size 181538259
Found chr4 with size 190214555
Found chr20 with size 64444167
Found chr13 with size 114364328
Found chr10 with size 133797422
Created 2898 genomic bins


In [17]:
print(mutations_processed.head())
print('\n', '*' * 100, '\n')
print(expression_matrix.head())

        Barcode   Gene Variant Reference Mutation      Start  End_Position  \
0  TCGA-17-Z017    CBL     SNP         C        T  119274922     119274922   
1  TCGA-AA-3811   IDH2     DEL         C        -   90088686      90088686   
2  TCGA-AD-5900  GATA1     SNP         C        T   48791282      48791282   
3  TCGA-AJ-A3BH  AXIN1     SNP         C        T     289531        289531   
4  TCGA-B5-A1MR   JAK3     SNP         C        T   17835938      17835938   

  Chromosome       primary_site genomic_bin       motif  
0      chr11  Bronchus and lung       [119]     TCG_C>T  
1      chr15              Colon       [226]  C_DEL_FS_C  
2       chrX              Colon       [286]     GCG_C>T  
3      chr16       Corpus uteri       [395]     CCG_C>T  
4      chr19       Corpus uteri       [503]     GCG_C>T  

 **************************************************************************************************** 

                  ABL1      AKT1       ALK     AMER1       APC        AR  \
TC

In [18]:
SEED = 42
BATCH_SIZE = 64
MOTIF_DIM = mutations_processed['motif'].nunique()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Set {DEVICE} device for computations.")

Set cuda device for computations.


In [19]:
def to_matrix(
    mutations_df: pd.DataFrame,
    expression_matrix: pd.DataFrame,
    genes: List[str],
    bin_mapping: dict,
    motif_to_index: dict,
    total_bins: int) -> pd.DataFrame:
    """ Create concatenated feature vectors.
    """
    bin_features = defaultdict(set)
    gene_features = defaultdict(set)
    motif_features = defaultdict(set)
    
    # Collect mutation features
    for _, row in mutations_df.iterrows():
        sample = row['Barcode']
        gene = row['Gene']
        
        if gene in genes:
            gene_features[sample].add(gene)
        
        if not pd.isna(row['genomic_bin']):
            for bin in row['genomic_bin']:
                bin_features[sample].add(bin)
   
        if not pd.isna(row['motif']):
            motif_features[sample].add(motif_to_index[row['motif']])
    
    all_samples = set(mutations_df['Barcode']) | set(expression_matrix.index)
    sample_features = []
    
    for sample in all_samples:
        bin_vector = [1 if i in bin_features[sample] else 0 
                        for i in range(total_bins)]
        
        gene_vector = [1 if gene in gene_features[sample] else 0 
                        for gene in genes]
        
        motif_vector = [1 if i in motif_features[sample] else 0 
                        for i in range(MOTIF_DIM)]

        rna_vector = expression_matrix.loc[sample][genes].values if sample in expression_matrix.index \
                    else np.zeros(len(genes))
        
        # Concatenate features
        full_vector = np.concatenate([bin_vector, gene_vector, motif_vector, rna_vector])
        sample_features.append((sample, full_vector))
    
    features_df = pd.DataFrame(sample_features, columns=['Barcode', 'features'])
    
    # Add primary_site labels
    primary_sites = mutations_df.groupby('Barcode')['primary_site'].first()
    features_df = features_df.merge(primary_sites, on='Barcode', how='left')
    
    return features_df.dropna(subset=['primary_site'])


features_df = to_matrix(
        mutations_processed,
        expression_matrix,
        vogelstein_genes,
        bin_mapping,
        motif_to_index,
        total_bins
    )

In [20]:
class TCGADataset(Dataset):
    def __init__(self,
                 df: pd.DataFrame,
                 n_bins: int,
                 n_genes: int):
        self.features = np.array(df['features'].tolist())
        self.labels = df['primary_site'].values
        
        # Store dimensions
        self.n_bins = n_bins
        self.n_genes = n_genes
        
        # Encode labels
        self.le = LabelEncoder()
        self.labels_encoded = self.le.fit_transform(self.labels)
        self.n_classes = len(self.le.classes_)
    
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        features = self.features[idx].astype(np.float32)
        X_bin = features[:self.n_bins]
        X_gene = features[self.n_bins:self.n_bins+self.n_genes]
        X_motif = features[self.n_bins+self.n_genes:self.n_bins+self.n_genes+MOTIF_DIM]
        X_rna = features[self.n_bins+self.n_genes+MOTIF_DIM:]
        
        return {
            'genomic_bins': torch.tensor(X_bin),
            'gene_mutations': torch.tensor(X_gene),
            'mutational_motifs': torch.tensor(X_motif),
            'rna_expression': torch.tensor(X_rna)
        }, torch.tensor(self.labels_encoded[idx])

In [21]:
n_genes = len(vogelstein_genes)
dataset = TCGADataset(features_df, total_bins, n_genes)
print(f"Created PyTorch dataset with {len(dataset)} examples for {n_genes} genes.")

Created PyTorch dataset with 8745 examples for 125 genes.


In [22]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [23]:
set_seed(SEED)
labels = dataset.labels_encoded
idx_all = np.arange(len(labels))
unique, counts = np.unique(labels, return_counts=True)
freq = dict(zip(unique, counts))

# Single vs Rest classes
rare_labels = {lab for lab, cnt in freq.items() if cnt < 2}
rare_idx = [i for i, lab in enumerate(labels) if lab in rare_labels]
rest_idx = [i for i, lab in enumerate(labels) if lab not in rare_labels]

# Train/Test = 80/20
train_rest, test_rest = train_test_split(
    rest_idx,
    test_size=0.2,
    stratify=labels[rest_idx],
    random_state=SEED
)

# Train/Val/Test = 70/10/20
train_rest, val_rest = train_test_split(
    train_rest,
    test_size=0.125,
    stratify=labels[train_rest],
    random_state=SEED
)
train_idx = np.concatenate([rare_idx, train_rest])
val_idx   = np.array(val_rest)
test_idx  = np.array(test_rest)

# Dataloaders
train_loader = DataLoader(
    torch.utils.data.Subset(dataset, train_idx),
    batch_size=BATCH_SIZE,
    shuffle=True
)
val_loader = DataLoader(
    torch.utils.data.Subset(dataset, val_idx),
    batch_size=BATCH_SIZE
)
test_loader = DataLoader(
    torch.utils.data.Subset(dataset, test_idx),
    batch_size=BATCH_SIZE
)

# Balanced weighs for unbalanced primary_sites
all_labels = labels[train_idx]
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(all_labels),
    y=all_labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

# 🤖 Architecture

In [24]:
class BasicClassifier(nn.Module):
    def __init__(self,
                 bin_dim: int,
                 gene_dim: int,
                 motif_dim: int,
                 rna_dim: int,
                 n_classes: int,
                 size: int):
        super().__init__()

        self.bin_processor = nn.Sequential(
            nn.Linear(bin_dim, size),
            nn.ReLU()
        )
        
        self.gene_processor = nn.Sequential(
            nn.Linear(gene_dim, size),
            nn.ReLU()
        )
        
        self.motif_processor = nn.Sequential(
            nn.Linear(motif_dim, size),
            nn.ReLU()
        )
        
        self.rna_processor = nn.Sequential(
            nn.Linear(rna_dim, size),
            nn.ReLU()
        )
        
        # Combined processing
        self.combined = nn.Sequential(
            nn.Linear(4 * size, 4 * size),
            nn.ReLU(),
            nn.Linear(4 * size, 2 * size),
            nn.ReLU(),
            nn.Linear(2 * size, size),
            nn.ReLU(),
            nn.Linear(size, n_classes)
        )
    
    def forward(self, x):
        # Process each group
        bin_out = self.bin_processor(x['genomic_bins'])
        gene_out = self.gene_processor(x['gene_mutations'])
        motif_out = self.motif_processor(x['mutational_motifs'])
        rna_out = self.rna_processor(x['rna_expression'])
        combined = torch.cat((bin_out, gene_out, motif_out, rna_out), dim=1)
        return self.combined(combined)

In [25]:
def train_model(
    model: nn.Module,
    train_loader: Dataset,
    val_loader: Dataset,
    n_classes: int,
    device: torch.device,
    optimizer=None,
    class_weights=None,
    epochs: int = 100,
    patience: int = 4,
    path: str = "best_model.pth"
) -> dict:
    """ Train the model with early stopping and record metrics
    of validation dataset. Returns a history dict containing
    losses and accuracies.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=3
    )
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [],   'val_acc': [],
        'val_precision': [], 'val_recall': [],
        'val_f1': []
    }

    best_val_acc = -1
    early_stop_counter = 0

    # Training
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        all_preds, all_labels = [], []
        for inputs, labels in train_loader:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * labels.size(0)
            preds = outputs.argmax(dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

        # Metrics
        y_true = torch.cat(all_labels).numpy()
        y_pred = torch.cat(all_preds).numpy()
        train_acc = balanced_accuracy_score(y_true, y_pred)
        train_loss = train_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        val_preds, val_labels = [], []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                preds = outputs.argmax(dim=1)
                val_preds.append(preds.cpu())
                val_labels.append(labels.cpu())

        # Compute validation metrics
        y_true_val = torch.cat(val_labels).numpy()
        y_pred_val = torch.cat(val_preds).numpy()
        mask = np.isin(y_pred_val, y_true_val) # otherwise warning
        val_acc = balanced_accuracy_score(y_true_val[mask], y_pred_val[mask])
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true_val, y_pred_val, average='weighted', zero_division=0
        )
        val_loss = val_loss / len(val_loader.dataset)

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_precision'].append(precision)
        history['val_recall'].append(recall)
        history['val_f1'].append(f1)
        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, \n"
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

        # Scheduler and early stopping
        scheduler.step(val_loss)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            early_stop_counter = 0
            torch.save(model.state_dict(), path)
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered!")
                break

    # Load best model
    model.load_state_dict(torch.load("/kaggle/working/" + path))
    history['model'] = model
    return history


In [26]:
def plot_training_history(history: Dict[str, float]) -> None:
    """ Plot loss and accuracy.
    """
    # Prepare data
    epochs = list(range(1, len(history['train_loss']) + 1))
    df_loss = {
        'epoch': epochs * 2,
        'loss': history['train_loss'] + history['val_loss'],
        'set': ['train'] * len(epochs) + ['validation'] * len(epochs)
    }
    df_acc = {
        'epoch': epochs * 2,
        'accuracy': history['train_acc'] + history['val_acc'],
        'set': ['train'] * len(epochs) + ['validation'] * len(epochs)
    }

    # Plot Loss
    fig_loss = px.line(
        df_loss, x='epoch', y='loss', color='set',
        title='Training and Validation Loss'
    )
    fig_loss.show()

    # Plot Accuracy
    fig_acc = px.line(
        df_acc, x='epoch', y='accuracy', color='set',
        title='Training and Validation Accuracy'
    )
    fig_acc.show()


In [27]:
set_seed(SEED)

model = BasicClassifier(
        bin_dim=total_bins,
        gene_dim=n_genes,
        motif_dim=MOTIF_DIM,
        rna_dim=n_genes,
        n_classes=dataset.n_classes,
        size=256,
).to(DEVICE)

basic_result = train_model(model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    n_classes=dataset.n_classes,
                    device=DEVICE,
                    class_weights=class_weights,
                    epochs=128,
                    patience=32
        )
basic_model = basic_result['model']

Epoch 1/128 | Train Loss: 3.7208, Acc: 0.0732 | Val Loss: 3.3033, Acc: 0.1152, 
Precision: 0.0926, Recall: 0.1259, F1-score: 0.0871
Epoch 2/128 | Train Loss: 3.1311, Acc: 0.1640 | Val Loss: 2.5762, Acc: 0.2509, 
Precision: 0.4266, Recall: 0.3421, F1-score: 0.3035
Epoch 3/128 | Train Loss: 2.4607, Acc: 0.2524 | Val Loss: 2.1850, Acc: 0.3738, 
Precision: 0.6125, Recall: 0.5160, F1-score: 0.5136
Epoch 4/128 | Train Loss: 2.0048, Acc: 0.3398 | Val Loss: 1.9433, Acc: 0.4722, 
Precision: 0.6895, Recall: 0.5641, F1-score: 0.5744
Epoch 5/128 | Train Loss: 1.6932, Acc: 0.3884 | Val Loss: 1.9526, Acc: 0.4466, 
Precision: 0.7353, Recall: 0.5664, F1-score: 0.5973
Epoch 6/128 | Train Loss: 1.4884, Acc: 0.4409 | Val Loss: 1.7170, Acc: 0.5152, 
Precision: 0.7218, Recall: 0.6590, F1-score: 0.6698
Epoch 7/128 | Train Loss: 1.3335, Acc: 0.4966 | Val Loss: 1.9880, Acc: 0.5062, 
Precision: 0.7566, Recall: 0.6018, F1-score: 0.6122
Epoch 8/128 | Train Loss: 1.1718, Acc: 0.5558 | Val Loss: 2.1282, Acc: 0.472

# 🎯 Assessing overfitting

In [28]:
plot_training_history(basic_result)

# 👾 Enhancing the model

In [29]:
class AdvancedClassifier(nn.Module):
    def __init__(
        self,
        bin_dim: int,
        gene_dim: int,
        motif_dim: int,
        rna_dim: int,
        n_classes: int,
        hidden_sizes: Dict[str, int],  # {'bin': 2,'gene': 2,'motif': 2,'rna': 2}
        dropout: float = 0.5,
        use_batchnorm: bool = True,
        combined_sizes: List[int] = None
    ):
        """
        Args:
            bin_dim: Dimensionality of genomic bin features
            gene_dim: Dimensionality of gene mutation features
            motif_dim: Dimensionality of mutational motif features
            rna_dim: Dimensionality of RNA expression features
            n_classes: Number of output classes
            hidden_sizes: Dict mapping feature names to hidden dims
            dropout: Dropout probability
            use_batchnorm: Whether to include BatchNorm1d after
                           each Linear layer.
            combined_sizes: list of ints for concatenated layer dims
        """
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.dropout = dropout

        def block(in_dim: int, out_dim: int) -> nn.Sequential:
            layers = [nn.Linear(in_dim, out_dim)]
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(out_dim))
            layers.extend([nn.ReLU(), nn.Dropout(dropout)])
            return nn.Sequential(*layers)

        # Dense layers
        self.bin_processor = block(bin_dim, hidden_sizes['bin'])
        self.gene_processor = block(gene_dim, hidden_sizes['gene'])
        self.motif_processor = block(motif_dim, hidden_sizes['motif'])
        self.rna_processor = block(rna_dim, hidden_sizes['rna'])

        # Combined feature
        concat_dim = sum(hidden_sizes.values())
        layers = []
        prev_dim = concat_dim

        if combined_sizes is None:
            combined_sizes = [concat_dim]
        for size in combined_sizes:
            layers.append(block(prev_dim, size))
            prev_dim = size

        layers.append(nn.Linear(prev_dim, n_classes))
        self.combined = nn.Sequential(*layers)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        bin_out = self.bin_processor(x['genomic_bins'])
        gene_out = self.gene_processor(x['gene_mutations'])
        motif_out = self.motif_processor(x['mutational_motifs'])
        rna_out = self.rna_processor(x['rna_expression'])
        combined = torch.cat([bin_out, gene_out, motif_out, rna_out], dim=1)
        return self.combined(combined)


def build_model_and_optimizer(
    config: Dict,
    device: torch.device) -> Dict[str, object]:
    """ Builds model using config dict with layer sizes.
    """
    model = AdvancedClassifier(
        bin_dim=config['bin_dim'],
        gene_dim=config['gene_dim'],
        motif_dim=config['motif_dim'],
        rna_dim=config['rna_dim'],
        n_classes=config['n_classes'],
        hidden_sizes=config.get('hidden_sizes', {
            'bin': 64,
            'gene': 64,
            'motif': 64,
            'rna': 64
        }),
        dropout=config.get('dropout', 0.5),
        use_batchnorm=config.get('batchnorm', True),
        combined_sizes=config.get('combined_sizes', [256, 128])
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.get('lr', 1e-3),
        weight_decay=config.get('l2_decay', 1e-4)
    )

    return {'model': model, 'optimizer': optimizer}

In [62]:
# Advanced model
set_seed(SEED)

config = {
    'bin_dim': total_bins,
    'gene_dim': n_genes,
    'motif_dim': MOTIF_DIM,
    'rna_dim': n_genes,
    'n_classes': dataset.n_classes,
    'hidden_sizes': {'bin': 128, 'gene': 128, 'motif':128, 'rna': 2048},
    'combined_sizes': [1024, 256, 256],
    'dropout': 0.25,
    'batchnorm': True,
    'lr': 8e-5,
    'l2_decay': 2e-5
}

advanced_dict = build_model_and_optimizer(config, DEVICE)
advanced_history = train_model(model=advanced_dict['model'],
                               train_loader=train_loader,
                               val_loader=val_loader,
                               optimizer=advanced_dict['optimizer'],
                               n_classes=dataset.n_classes,
                               device=DEVICE,
                               class_weights=class_weights,
                               epochs=128,
                               patience=32,
                               path="best_advanced_model.pth"
)
advanced_model = advanced_history['model']

Epoch 1/128 | Train Loss: 3.5895, Acc: 0.1231 | Val Loss: 3.0020, Acc: 0.4846, 
Precision: 0.7024, Recall: 0.6281, F1-score: 0.6421
Epoch 2/128 | Train Loss: 2.9235, Acc: 0.3221 | Val Loss: 2.5061, Acc: 0.5655, 
Precision: 0.7863, Recall: 0.7117, F1-score: 0.7264
Epoch 3/128 | Train Loss: 2.4901, Acc: 0.4131 | Val Loss: 2.1769, Acc: 0.6084, 
Precision: 0.8281, Recall: 0.7597, F1-score: 0.7769
Epoch 4/128 | Train Loss: 2.2248, Acc: 0.4666 | Val Loss: 1.9743, Acc: 0.6580, 
Precision: 0.8471, Recall: 0.7609, F1-score: 0.7912
Epoch 5/128 | Train Loss: 1.9828, Acc: 0.5486 | Val Loss: 1.7706, Acc: 0.6430, 
Precision: 0.8414, Recall: 0.7826, F1-score: 0.7949
Epoch 6/128 | Train Loss: 1.7599, Acc: 0.5794 | Val Loss: 1.6223, Acc: 0.6652, 
Precision: 0.8644, Recall: 0.7838, F1-score: 0.8124
Epoch 7/128 | Train Loss: 1.6260, Acc: 0.6048 | Val Loss: 1.5387, Acc: 0.6739, 
Precision: 0.8543, Recall: 0.7780, F1-score: 0.8056
Epoch 8/128 | Train Loss: 1.4457, Acc: 0.6799 | Val Loss: 1.4690, Acc: 0.682

In [63]:
plot_training_history(advanced_history)

In [64]:
def test_model(model: nn.Module,
               path: str) -> None:
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_preds.append(predicted.cpu())
            all_labels.append(labels.cpu())

    # Compute metrics
    y_true = torch.cat(all_labels).numpy()
    y_pred = torch.cat(all_preds).numpy()
    accuracy = balanced_accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )

    print(f"Test Balanced Accuracy:  {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall:    {recall:.4f}")
    print(f"Test F1-score:  {f1:.4f}")

    # Per-class report
    print("\nClassification Report:")
    print(classification_report(
        y_true,
        y_pred,
        digits=4,
        zero_division=0
    ))

    # Save final model
    torch.save(model.state_dict(), path)


In [65]:
test_model(model, "final_basic_model.pth")
test_model(advanced_model, "final_advanced_model.pth")
print("That's all folks!")

Test Balanced Accuracy:  0.5108
Test Precision: 0.7749
Test Recall:    0.7168
Test F1-score:  0.7328

Classification Report:
              precision    recall  f1-score   support

           0     0.7200    0.8182    0.7660        22
           1     0.0000    0.0000    0.0000         4
           2     0.6778    0.7722    0.7219        79
           5     0.9686    0.9390    0.9536       164
           6     0.9684    0.8693    0.9162       176
           7     0.8125    0.5174    0.6322       201
           8     0.5538    0.6923    0.6154        52
           9     0.5106    0.2857    0.3664        84
          11     0.3939    0.7647    0.5200        17
          12     0.9429    0.6408    0.7630       103
          13     0.5484    0.4857    0.5152        35
          14     0.7500    1.0000    0.8571        15
          15     0.3333    0.0909    0.1429        11
          17     0.1111    0.5000    0.1818         2
          18     0.7778    0.5000    0.6087        14
          