In [1]:
import pandas as pd
import numpy as np
from datasets import Dataset
from collections import Counter
import gc
import time
import psutil
import os

start_time = time.time()
process = psutil.Process(os.getpid())

def print_memory_usage(step_name):
    mem_info = process.memory_info()
    mem_gb = mem_info.rss / (1024 ** 3)

    print(f"{step_name} - Memory Usage: {mem_gb:.2f} GB")
    return mem_gb

def print_step_header(step_num, step_name):
    separator = "=" * 60

    print("\n" + separator)
    print(f"STEP {step_num}: {step_name}")
    print(separator)

initial_mem = print_memory_usage("Initial")

print("\nSTARTING PROTEIN GO TERM PROCESSING")
print(f"Start Time: {time.strftime('%H:%M:%S')}")
print(f"Initial Memory: {initial_mem:.2f} GB")

  from .autonotebook import tqdm as notebook_tqdm


Initial - Memory Usage: 0.28 GB

STARTING PROTEIN GO TERM PROCESSING
Start Time: 11:24:53
Initial Memory: 0.28 GB


In [2]:
from Bio import SeqIO

tsv_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/raw/train/uniprotkb_AND_reviewed_true_AND_protein_2025_12_27.tsv"
labels_df = pd.read_csv(tsv_path, sep='\t')

fasta_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/raw/train/uniprotkb_AND_reviewed_true_AND_protein_2025_12_27.fasta"
sequence_dict = {}

for record in SeqIO.parse(fasta_path, "fasta"):

    header = record.id
    if "|" in header:
        protein_id = header.split("|")[1]
    else:
        protein_id = header.split()[0]
    sequence_dict[protein_id] = str(record.seq)

print(f"TSV entries: {len(labels_df)}")
print(f"FASTA sequences: {len(sequence_dict)}")

TSV entries: 105951
FASTA sequences: 105951


In [3]:
def parse_go_terms(go_string):

    if pd.isna(go_string) or go_string == "":
        return []

    return [term.strip() for term in str(go_string).split(';')]

labels_df['go_terms_list'] = labels_df['Gene Ontology IDs'].apply(parse_go_terms)

In [4]:
labels_df.head()

Unnamed: 0,Entry,Entry Name,Protein names,Organism,Sequence,Gene Ontology IDs,go_terms_list
0,A0A009IHW8,ABTIR_ACIB9,2' cyclic ADP-D-ribose synthase AbTIR (2'cADPR...,Acinetobacter baumannii (strain 1295743),MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,GO:0003953; GO:0007165; GO:0019677; GO:0050135...,"[GO:0003953, GO:0007165, GO:0019677, GO:005013..."
1,A0A023I7E1,ENG1_RHIMI,"Glucan endo-1,3-beta-D-glucosidase 1 (Endo-1,3...",Rhizomucor miehei,MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF...,GO:0000272; GO:0005576; GO:0042973; GO:0052861...,"[GO:0000272, GO:0005576, GO:0042973, GO:005286..."
2,A0A024B7W1,POLG_ZIKVF,Genome polyprotein [Cleaved into: Capsid prote...,Zika virus (isolate ZIKV/Human/French Polynesi...,MKNPKKKSGGFRIVNMLKRGVARVSPFGGLKRLPAGLLLGHGPIRM...,GO:0003724; GO:0003725; GO:0003968; GO:0004252...,"[GO:0003724, GO:0003725, GO:0003968, GO:000425..."
3,A0A024RXP8,GUX1_HYPJR,"Exoglucanase 1 (EC 3.2.1.91) (1,4-beta-cellobi...",Hypocrea jecorina (strain ATCC 56765 / BCRC 32...,MYRKLAVISAFLATARAQSACTLQSETHPPLTWQKCSSGGTCTQQT...,GO:0005576; GO:0016162; GO:0030245; GO:0030248,"[GO:0005576, GO:0016162, GO:0030245, GO:0030248]"
4,A0A024SC78,CUTI1_HYPJR,Cutinase (EC 3.1.1.74),Hypocrea jecorina (strain ATCC 56765 / BCRC 32...,MRSLAILTTLLAGHAFAYPKPAPQSVNRRDWPSINEFLSELAKVMP...,GO:0005576; GO:0016052; GO:0050525,"[GO:0005576, GO:0016052, GO:0050525]"


In [5]:
filtered_df = labels_df[labels_df['Entry'].isin(sequence_dict.keys())].copy()

filtered_df['sequence'] = filtered_df['Entry'].map(sequence_dict)

train_df = filtered_df[['Entry', 'sequence', 'go_terms_list', 'Organism']].rename(
    columns={
        'Entry': 'accession',
        'go_terms_list': 'go_terms',
        'Organism': 'organism'
    }
)

print(f"Matched proteins: {len(train_df)}")

output_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/training_data_combined.csv"
train_df.to_csv(output_path, index=False)
print(f"Saved to: {output_path}")

Matched proteins: 105951
Saved to: C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/training_data_combined.csv


In [6]:
train_df.head()

Unnamed: 0,accession,sequence,go_terms,organism
0,A0A009IHW8,MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,"[GO:0003953, GO:0007165, GO:0019677, GO:005013...",Acinetobacter baumannii (strain 1295743)
1,A0A023I7E1,MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF...,"[GO:0000272, GO:0005576, GO:0042973, GO:005286...",Rhizomucor miehei
2,A0A024B7W1,MKNPKKKSGGFRIVNMLKRGVARVSPFGGLKRLPAGLLLGHGPIRM...,"[GO:0003724, GO:0003725, GO:0003968, GO:000425...",Zika virus (isolate ZIKV/Human/French Polynesi...
3,A0A024RXP8,MYRKLAVISAFLATARAQSACTLQSETHPPLTWQKCSSGGTCTQQT...,"[GO:0005576, GO:0016162, GO:0030245, GO:0030248]",Hypocrea jecorina (strain ATCC 56765 / BCRC 32...
4,A0A024SC78,MRSLAILTTLLAGHAFAYPKPAPQSVNRRDWPSINEFLSELAKVMP...,"[GO:0005576, GO:0016052, GO:0050525]",Hypocrea jecorina (strain ATCC 56765 / BCRC 32...


In [7]:
print(f"Total proteins: {len(train_df)}")

print("\nCollecting unique GO terms and computing frequencies...")
go_counts = Counter(term for go_list in train_df['go_terms'] for term in go_list)
all_go_terms = set(go_counts.keys())

print(f"Found {len(all_go_terms)} unique GO terms total")
mem1 = print_memory_usage("After collecting all terms")

ultra_general = {'GO:0008150', 'GO:0005575', 'GO:0003674'}

filtered_terms = [
    term for term, count in go_counts.most_common()
    if term not in ultra_general
]

top_n_frequent = 2000
top_go_frequent = filtered_terms[:top_n_frequent]

remaining_terms = filtered_terms[top_n_frequent:]

remaining_filtered = [
    term for term in remaining_terms
    if go_counts[term] >= 5
]
np.random.seed(42)

top_go_random = (
    np.random.choice(remaining_filtered, 761, replace=False).tolist()
    if len(remaining_filtered) >= 761
    else remaining_filtered
)

top_go_terms = top_go_frequent + top_go_random

print(f"\nSelected {len(top_go_frequent)} frequent + "
      f"{len(top_go_random)} random = {len(top_go_terms)} total GO terms")

print(f"Most common term: {top_go_frequent[0]} "
      f"(appears {go_counts[top_go_frequent[0]]} times)")

if top_go_random:
    print(f"Sample random term: {top_go_random[0]} "
          f"(appears {go_counts[top_go_random[0]]} times)")

mem2 = print_memory_usage("After term selection")

print(f"Time elapsed: {time.time() - start_time:.1f} seconds")

Total proteins: 105951

Collecting unique GO terms and computing frequencies...
Found 27615 unique GO terms total
After collecting all terms - Memory Usage: 0.54 GB

Selected 2000 frequent + 761 random = 2761 total GO terms
Most common term: GO:0005737 (appears 22662 times)
Sample random term: GO:0016311 (appears 62 times)
After term selection - Memory Usage: 0.54 GB
Time elapsed: 8.7 seconds


In [8]:
go_to_index = {go: idx for idx, go in enumerate(top_go_terms)}

num_proteins = len(train_df)
num_labels = len(top_go_terms)

print(f"Total proteins: {num_proteins:,}")
print(f"Total selected GO terms: {num_labels:,}")

binary_matrix = np.zeros((num_proteins, num_labels), dtype=np.int8)
fill_start_time = time.time()
for protein_idx, go_list in enumerate(train_df['go_terms']):
    for go in go_list:
        if go in go_to_index:
            label_idx = go_to_index[go]

            binary_matrix[protein_idx, label_idx] = 1
column_names = [f"label_{go}" for go in top_go_terms]

binary_df = pd.DataFrame(binary_matrix, columns=column_names)

print(f"Created binary DataFrame with {binary_df.shape[1]:,} label columns")

del binary_matrix
gc.collect()
mem3 = print_memory_usage("After binary matrix creation")

print(f"Matrix filling time: {time.time() - fill_start_time:.1f} seconds")
print(f"Total time elapsed: {time.time() - start_time:.1f} seconds")

Total proteins: 105,951
Total selected GO terms: 2,761
Created binary DataFrame with 2,761 label columns
After binary matrix creation - Memory Usage: 0.50 GB
Matrix filling time: 2.4 seconds
Total time elapsed: 11.1 seconds


In [9]:
final_df = pd.concat(
    [
        train_df[['sequence', 'accession']],  # original features
        binary_df                              # binary label matrix
    ], axis=1
)
print(f"Final DataFrame shape: {final_df.shape[0]:,} proteins √ó {final_df.shape[1]:,} columns")

del binary_df
gc.collect()
print("\nVERIFICATION:")

example_accession = final_df.iloc[0]['accession']
example_sequence = final_df.iloc[0]['sequence']

print(f"Example protein: {example_accession}")
print(f"Sequence length: {len(example_sequence):,}")

label_columns = [col for col in final_df.columns if col.startswith('label_')]
print(f"Number of GO term labels: {len(label_columns):,}")

positive_counts = final_df.loc[0, label_columns].sum()
print(f"Positive labels for first protein: {positive_counts} out of {len(label_columns)}")
empty_labels = [col for col in label_columns if final_df[col].sum() == 0]

if empty_labels:
    print(f"Warning: {len(empty_labels)} labels have no positive examples!")
else:
    print("All labels have at least one positive example")

mem4 = print_memory_usage("After final dataframe")
print(f"Total time elapsed: {time.time() - start_time:.1f} seconds")

Final DataFrame shape: 105,951 proteins √ó 2,763 columns

VERIFICATION:
Example protein: A0A009IHW8
Sequence length: 269
Number of GO term labels: 2,761
Positive labels for first protein: 2 out of 2761
All labels have at least one positive example
After final dataframe - Memory Usage: 0.57 GB
Total time elapsed: 16.6 seconds


In [10]:
hf_start_time = time.time()

dataset = Dataset.from_pandas(final_df, preserve_index=False)
print(f"Dataset created with {len(dataset):,} proteins")

label_columns = [col for col in dataset.column_names if col.startswith('label_')]
print(f"Number of label columns: {len(label_columns):,}")

print(f"Conversion time: {time.time() - hf_start_time:.1f} seconds")

save_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/protein_go_dataset"
print(f"Saving HuggingFace Dataset to: {save_path}")

dataset.save_to_disk(save_path)
print("Dataset saved successfully.")

metadata = pd.DataFrame({
    'go_term': top_go_terms,
    'count': [go_counts[term] for term in top_go_terms],
    'percentage': [(go_counts[term] / len(final_df)) * 100 for term in top_go_terms],
    'is_frequent': [True] * len(top_go_frequent) + [False] * len(top_go_random)
})

metadata_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/go_terms_metadata.csv"
metadata.to_csv(metadata_path, index=False)
print(f"Saved metadata to: {metadata_path}")

total_time = time.time() - start_time
final_mem = print_memory_usage("Final")

print("\nPERFORMANCE SUMMARY:")
print(f"Total processing time: {total_time:.1f} seconds ({total_time/60:.1f} minutes)")
print(f"Memory increase: {final_mem - initial_mem:.2f} GB")
print(f"Time: {time.strftime('%H:%M:%S')}")

print("\nDATA SUMMARY:")
print(f"  ‚Ä¢ Input proteins: {len(train_df):,}")
print(f"  ‚Ä¢ Unique GO terms (original): {len(all_go_terms):,}")
print(f"  ‚Ä¢ Selected GO terms: {len(top_go_terms):,}")
print(f"  ‚Ä¢ Final dataset size: {final_df.shape[0]:,} √ó {final_df.shape[1]:,}")
print(f"  ‚Ä¢ Label representation: multi-hot binary (int8)")

print("\nAll steps completed successfully!")

Dataset created with 105,951 proteins
Number of label columns: 2,761
Conversion time: 24.9 seconds
Saving HuggingFace Dataset to: C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/protein_go_dataset


Saving the dataset (1/1 shards): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 105951/105951 [23:51<00:00, 73.99 examples/s]

Dataset saved successfully.
Saved metadata to: C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/go_terms_metadata.csv
Final - Memory Usage: 0.59 GB

PERFORMANCE SUMMARY:
Total processing time: 1473.7 seconds (24.6 minutes)
Memory increase: 0.30 GB
Time: 11:49:27

DATA SUMMARY:
  ‚Ä¢ Input proteins: 105,951
  ‚Ä¢ Unique GO terms (original): 27,615
  ‚Ä¢ Selected GO terms: 2,761
  ‚Ä¢ Final dataset size: 105,951 √ó 2,763
  ‚Ä¢ Label representation: multi-hot binary (int8)

All steps completed successfully!





In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
from datasets import load_from_disk

print("‚úì All libraries imported successfully")
print_memory_usage("After imports")

def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    print(f"‚úì Random seed set to: {seed}")

set_seed(42)
print_memory_usage("After setting seeds")

print("Loading dataset...")
mem_before = print_memory_usage("Before loading")
dataset_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/protein_go_dataset"
dataset = load_from_disk(dataset_path)
print(f"‚úì Loaded: {len(dataset):,} proteins")

mem_after = print_memory_usage("After loading")
print(f"‚úì Memory used: {mem_after - mem_before:.3f} GB")

label_columns = [col for col in dataset.column_names if col.startswith("label_")]
print(f"‚úì Number of GO term labels: {len(label_columns):,}")

dataset.set_format(
    type="torch",
    columns=label_columns  # labels as tensors
)
print("‚úì Dataset formatted for PyTorch (labels as tensors)")

‚úì All libraries imported successfully
After imports - Memory Usage: 0.63 GB
‚úì Random seed set to: 42
After setting seeds - Memory Usage: 0.63 GB
Loading dataset...
Before loading - Memory Usage: 0.63 GB
‚úì Loaded: 105,951 proteins
After loading - Memory Usage: 0.73 GB
‚úì Memory used: 0.107 GB
‚úì Number of GO term labels: 2,761
‚úì Dataset formatted for PyTorch (labels as tensors)


In [12]:
dataset_path = "C:/Users/USER/Documents/cod3astro/ML_AI/ProteinSeq_DL/data/processed/protein_go_dataset"
dataset = load_from_disk(dataset_path)

print("‚úì Dataset loaded")
print(f"Total proteins: {len(dataset)}")

seq_lengths = [len(seq) for seq in dataset["sequence"]]
seq_lengths = np.array(seq_lengths)

print("\nüìä Sequence Length Statistics")
print(f"Min length: {seq_lengths.min()}")
print(f"Max length: {seq_lengths.max()}")
print(f"Mean length: {seq_lengths.mean():.2f}")
print(f"Median length: {np.median(seq_lengths)}")
print(f"Std deviation: {seq_lengths.std():.2f}")

‚úì Dataset loaded
Total proteins: 105951

üìä Sequence Length Statistics
Min length: 2
Max length: 35213
Mean length: 472.26
Median length: 358.0
Std deviation: 535.95


In [13]:
bins = list(range(0, 2000, 250))  # 0,250,500,...,2000
bins.append(float("inf"))         # >2000 bin

hist, bin_edges = np.histogram(seq_lengths, bins=bins)

print("\nüì¶ Sequence Length Distribution")
for i in range(len(hist)):
    start = int(bin_edges[i])
    end = int(bin_edges[i+1]) if bin_edges[i+1] != float("inf") else "‚àû"
    print(f"{start:>4} - {end:<4} : {hist[i]} proteins")


üì¶ Sequence Length Distribution
   0 - 250  : 35261 proteins
 250 - 500  : 37285 proteins
 500 - 750  : 17379 proteins
 750 - 1000 : 7145 proteins
1000 - 1250 : 3509 proteins
1250 - 1500 : 1878 proteins
1500 - 1750 : 956 proteins
1750 - ‚àû    : 2538 proteins


In [14]:
print("\nüîç Checking for duplicate sequences...")

unique_sequences = set(dataset["sequence"]) 
total_sequences = len(dataset) 
unique_count = len(unique_sequences) 
duplicates = total_sequences - unique_count

print(f"Total sequences: {total_sequences}")
print(f"Unique sequences: {unique_count}")
print(f"Duplicate sequences: {duplicates}")


üîç Checking for duplicate sequences...
Total sequences: 105951
Unique sequences: 102750
Duplicate sequences: 3201


In [15]:
print("üîç Removing duplicate sequences (Arrow-safe method)...")

original_size = len(dataset)
print(f"Original dataset size: {original_size}")

sequences = dataset["sequence"]

seen = {}
keep_indices = []
for idx, seq in enumerate(sequences):
    if seq not in seen:
        seen[seq] = True
        keep_indices.append(idx)

print(f"Keeping {len(keep_indices)} unique sequences")
dataset = dataset.select(keep_indices)

new_size = len(dataset)
print(f"New dataset size: {new_size}")
print(f"Removed duplicates: {original_size - new_size}")

üîç Removing duplicate sequences (Arrow-safe method)...
Original dataset size: 105951
Keeping 102750 unique sequences
New dataset size: 102750
Removed duplicates: 3201


In [16]:
class ProteinDataset(Dataset):

    def __init__(self, hf_dataset, max_seq_length=512):
        self.dataset = hf_dataset
        self.max_seq_length = max_seq_length

        self.label_cols = [
            col for col in hf_dataset.column_names
            if col.startswith("label_")
        ]

        self.amino_acids = "ACDEFGHIKLMNPQRSTVWY"
        self.aa_to_idx = {
            aa: i + 1 for i, aa in enumerate(self.amino_acids)
        }
        print(f"  Found {len(self.label_cols)} GO term labels")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sequence = self.dataset[idx]["sequence"]
        encoded_seq = self.encode_sequence(sequence)

        labels = [
            self.dataset[idx][col]
            for col in self.label_cols
        ]
        return {
            "sequence": torch.tensor(encoded_seq, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.float32)
        }
    
    def encode_sequence(self, sequence):
        sequence = sequence[:self.max_seq_length]
        encoded = [
            self.aa_to_idx.get(aa, 0)  # 0 if unknown
            for aa in sequence
        ]

        padding_length = self.max_seq_length - len(encoded)
        if padding_length > 0:
            encoded += [0] * padding_length

        return encoded

print("‚úì ProteinDataset class created successfully")
print_memory_usage("After creating dataset class")

‚úì ProteinDataset class created successfully
After creating dataset class - Memory Usage: 0.82 GB


0.8194694519042969

In [17]:
mem_before = print_memory_usage("Before splitting data")

split_1 = dataset.train_test_split(
    test_size=0.3,
    seed=42
)

train_dataset_arrow = split_1["train"]
temp_dataset_arrow = split_1["test"]

split_2 = temp_dataset_arrow.train_test_split(
    test_size=0.5,
    seed=42
)
val_dataset_arrow = split_2["train"]
test_dataset_arrow = split_2["test"]

train_dataset = ProteinDataset(train_dataset_arrow, max_seq_length=512)
val_dataset = ProteinDataset(val_dataset_arrow, max_seq_length=512)
test_dataset = ProteinDataset(test_dataset_arrow, max_seq_length=512)

print("‚úì Data split completed:")
print(f"  Training samples: {len(train_dataset):,}")
print(f"  Validation samples: {len(val_dataset):,}")
print(f"  Test samples: {len(test_dataset):,}")

mem_after = print_memory_usage("After splitting data")
print(f"  Memory used by splits: {mem_after - mem_before:.2f} GB")

gc.collect()
print("  Garbage collection performed")
print_memory_usage("After garbage collection")

Before splitting data - Memory Usage: 0.82 GB
  Found 2761 GO term labels
  Found 2761 GO term labels
  Found 2761 GO term labels
‚úì Data split completed:
  Training samples: 71,925
  Validation samples: 15,412
  Test samples: 15,413
After splitting data - Memory Usage: 0.84 GB
  Memory used by splits: 0.02 GB
  Garbage collection performed
After garbage collection - Memory Usage: 0.84 GB


0.839019775390625

In [18]:
class SimpleProteinCNN(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=21,
            embedding_dim=64,
            padding_idx=0
        )

        self.conv_layers = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv1d(128, 64, kernel_size=5, padding=2),
            nn.ReLU(),

            nn.Conv1d(64, 32, kernel_size=7, padding=3),
            nn.ReLU()
        )

        self.global_pool = nn.AdaptiveMaxPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(0, 2, 1)
        x = self.conv_layers(x)
        x = self.global_pool(x)
        x = x.squeeze(-1)

        logits = self.classifier(x)
        return logits

print("‚úì SimpleProteinCNN class created")

sample_model = SimpleProteinCNN(num_classes=len(train_dataset.label_cols))
total_params = sum(p.numel() for p in sample_model.parameters())

print(f"  Model parameters: {total_params:,}")
print_memory_usage("After creating model")

‚úì SimpleProteinCNN class created


  Model parameters: 273,385
After creating model - Memory Usage: 0.84 GB


0.8419723510742188

In [None]:
print("Step 1: Setting up device...")

device = torch.device('cpu')
print("‚úì Using CPU (most reliable for laptops)")
print("\nStep 2: Checking our data...")
print(f"We have {len(train_dataset)} training proteins")

if hasattr(train_dataset, 'label_cols'):
    print(f"Found {len(train_dataset.label_cols)} GO term labels")
else:
    print("Warning: No label columns found!")
    num_dummy_labels = 100
    print(f"Creating {num_dummy_labels} dummy labels for testing")

print("\nStep 3: Creating data loaders...")

batch_size = 8 

train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    num_workers=0  
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)

print(f"‚úì Created data loaders")
print(f"  Batch size: {batch_size}")
print(f"  Training batches: {len(train_loader)}")
print("\nCreating a simpler model...")

class VerySimpleProteinCNN(nn.Module):
    
    def __init__(self, num_classes=100):
        super(VerySimpleProteinCNN, self).__init__()
        
        self.embedding = nn.Embedding(21, 32, padding_idx=0)
        
        self.conv = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        
        self.pool = nn.AdaptiveMaxPool1d(1)
        
        self.fc = nn.Linear(64, num_classes)
        
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(0, 2, 1)
        
        x = self.relu(self.conv(x))
        
        x = self.pool(x)
        x = x.squeeze(-1)
        
        x = self.fc(x)
        return x

try:
    num_classes = 100  
    model = VerySimpleProteinCNN(num_classes=num_classes)
    model = model.to(device)
    print(f"‚úì Created VerySimpleProteinCNN")
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {total_params:,}")
    
except Exception as e:
    print(f"Error creating model: {e}")
    print("Creating the simplest possible model...")
    
    class LinearModel(nn.Module):
        def __init__(self, num_classes=100):
            super(LinearModel, self).__init__()
            self.fc = nn.Linear(256, num_classes)  # Sequence length is 256
            
        def forward(self, x):
            return self.fc(x.float().mean(dim=1))
    
    model = LinearModel(num_classes=100)
    model = model.to(device)
    print("‚úì Created LinearModel (simplest possible)")

print("\nStep 5: Setting up optimizer (with workaround)...")

import torch
torch.cuda.empty_cache() if torch.cuda.is_available() else None

criterion = nn.BCEWithLogitsLoss()
print(f"‚úì Created loss function")

optimizers_to_try = [
    ('SGD', optim.SGD(model.parameters(), lr=0.01)),
    ('RMSprop', optim.RMSprop(model.parameters(), lr=0.001)),
    ('Adagrad', optim.Adagrad(model.parameters(), lr=0.01))
]

try:
    print("Trying Adam optimizer...")
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
    print("‚úì Adam optimizer created successfully!")
    
except Exception as e:
    print(f"Adam failed: {e}")
    print("Trying SGD instead...")
    
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    print("‚úì Using SGD optimizer instead")

print("\nStep 6: Testing everything works...")

try:
    test_batch = next(iter(train_loader))
    sequences = test_batch['sequence'].to(device)
    labels = test_batch['labels'].to(device)
    
    print(f"‚úì Got batch: sequences={sequences.shape}, labels={labels.shape}")
    
    with torch.no_grad():  # Don't calculate gradients for testing
        outputs = model(sequences)
        print(f"‚úì Model can make predictions: outputs={outputs.shape}")
    
    loss = criterion(outputs, labels)
    print(f"‚úì Can calculate loss: {loss.item():.4f}")
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"‚úì Can do training step (backpropagation)")
    
    print("\n ALL TESTS PASSED! Ready for training.")
    
except Exception as e:
    print(f"‚ùå Test failed: {e}")
    print("\nDebugging info:")
    
    print(f"  Device: {device}")
    print(f"  Model type: {type(model)}")
    print(f"  Model on device? {next(model.parameters()).device}")
    
    print("\nTrying to fix...")
    
    model = VerySimpleProteinCNN(num_classes=100)
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    print("‚úì Reset model and optimizer")

print("\n" + "="*50)
print("FINAL TRAINING SETUP:")
print("="*50)

print(f"1. Device: {device}")
print(f"2. Model: {model.__class__.__name__}")
print(f"3. Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"4. Optimizer: {optimizer.__class__.__name__}")
print(f"5. Batch size: {batch_size}")
print(f"6. Training samples: {len(train_dataset)}")
print(f"7. Validation samples: {len(val_dataset)}")

print_memory_usage("After setup")
print("READY FOR TRAINING!")


STEP 14: PREPARING FOR TRAINING - SIMPLE VERSION
Step 1: Setting up device...
‚úì Using CPU (most reliable for laptops)

Step 2: Checking our data...
We have 71925 training proteins
Found 2761 GO term labels

Step 3: Creating data loaders...
‚úì Created data loaders
  Batch size: 8
  Training batches: 8991

Creating a simpler model...
‚úì Created VerySimpleProteinCNN
  Parameters: 13,380

Step 5: Setting up optimizer (with workaround)...
‚úì Created loss function
Trying Adam optimizer...
‚úì Adam optimizer created successfully!

Step 6: Testing everything works...
‚úì Got batch: sequences=torch.Size([512]), labels=torch.Size([512, 8])
‚ùå Test failed: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3

Debugging info:
  Device: cpu
  Model type: <class '__main__.VerySimpleProteinCNN'>
  Model on device? cpu

Trying to fix...
‚úì Reset model and optimizer

FINAL

In [20]:
print_step_header(14, "TRAINING - SIMPLE VERSION")

# ============================================
# PART 1: TRAINING SETUP
# ============================================

print("Starting training...")
print(f"Training samples: {len(train_dataset)}")
print(f"Batch size: {batch_size}")
print(f"Total batches per epoch: {len(train_loader)}")

# Create lists to track progress
train_losses = []
val_losses = []

# How many times to go through all data
num_epochs = 3  
best_loss = float('inf')  

# ============================================
# PART 2: THE TRAINING LOOP
# ============================================
for epoch in range(num_epochs):
    print(f"\nüìÖ Epoch {epoch + 1}/{num_epochs}")
    print("-" * 30)
    
    # ===== TRAINING PHASE =====
    model.train() 
    total_train_loss = 0
    
    for batch_idx, batch in enumerate(train_loader):
        # Get data
        sequences = batch['sequence'].to(device)
        labels = batch['labels'].to(device)
        
        # Clear old gradients
        optimizer.zero_grad()
        
        # Make predictions
        outputs = model(sequences)
        
        # Calculate loss 
        loss = criterion(outputs, labels)
        
        # Learn from mistakes
        loss.backward()
        optimizer.step()
        
        # Track loss
        total_train_loss += loss.item()
        
        # Show progress every 1000 batches
        if (batch_idx + 1) % 1000 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)} - Loss: {loss.item():.4f}")
    
    # Calculate average training loss for this epoch
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    
    # ===== VALIDATION PHASE =====
    model.eval()  # Set model to evaluation mode
    total_val_loss = 0
    
    with torch.no_grad():  # No learning during validation
        for batch in val_loader:
            sequences = batch['sequence'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(sequences)
            loss = criterion(outputs, labels)
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    
    # ===== SHOW RESULTS =====
    print(f"\nüìä Results for Epoch {epoch + 1}:")
    print(f"  Training Loss: {avg_train_loss:.4f}")
    print(f"  Validation Loss: {avg_val_loss:.4f}")
    
    # ===== SAVE BEST MODEL =====
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"  ‚úÖ Saved better model! (Loss improved to {avg_val_loss:.4f})")
    else:
        print(f"  Model did not improve")

# ============================================
# PART 3: TRAINING COMPLETE
# ============================================

print("\n" + "="*50)
print("TRAINING COMPLETE!")
print("="*50)
print(f"‚úì Best model saved as: best_model.pth")
print(f"‚úì Final training loss: {train_losses[-1]:.4f}")
print(f"‚úì Final validation loss: {val_losses[-1]:.4f}")

# ============================================
# PART 4: QUICK SUMMARY
# ============================================

print("\nüìà Training Summary:")
print(f"  Epochs completed: {num_epochs}")
print(f"  Training loss started at: {train_losses[0]:.4f}")
print(f"  Training loss ended at: {train_losses[-1]:.4f}")
if len(train_losses) > 1:
    improvement = train_losses[0] - train_losses[-1]
    print(f"  Improvement: {improvement:.4f}")



STEP 14: TRAINING - SIMPLE VERSION
Starting training...
Training samples: 71925
Batch size: 8
Total batches per epoch: 8991

üìÖ Epoch 1/3
------------------------------


RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3