In [1]:
%pip install biopython transformers torch datasets numpy scikit-learn evaluate

Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install transformers[torch]

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments
)
import numpy as np
import os
from Bio import SeqIO
import evaluate
import inspect

print("✅ Imports complete")

  from .autonotebook import tqdm as notebook_tqdm


✅ Imports complete


In [4]:
# Clear GPU cache before starting
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"✅ GPU cache cleared")
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"   Available memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1e9:.2f} GB")

✅ GPU cache cleared
   GPU: NVIDIA GeForce RTX 4090
   Total GPU memory: 25.25 GB
   Available memory: 25.25 GB


In [5]:
tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
print("✅ Tokenizer loaded")

✅ Tokenizer loaded


In [6]:
class FASTADataset(Dataset):
    def __init__(self, fasta_file, tokenizer, max_length=512):
        print(f"Loading sequences from {fasta_file}...")
        self.sequences = []
        for record in SeqIO.parse(fasta_file, "fasta"):
            seq = str(record.seq)
            if len(seq) > 0:
                self.sequences.append(seq)
        
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"Loaded {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        spaced_seq = " ".join(list(seq))
        
        encoding = self.tokenizer(
            spaced_seq,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )
        
        result = {k: v.squeeze(0) for k, v in encoding.items()}
        
        if 'attention_mask' not in result:
            result['attention_mask'] = torch.ones_like(result['input_ids'])
        
        return result

print("✅ FASTADataset class defined")

✅ FASTADataset class defined


In [7]:
data_folder = '/home/mluser/AFML_RISHABH/Project/10k sequences'

train_path = os.path.join(data_folder, "kinases_cluster_train_10k.fasta")
val_path   = os.path.join(data_folder, "kinases_cluster_val_10k.fasta")
test_path  = os.path.join(data_folder, "kinases_cluster_test_10k.fasta")

# Verify files exist
for path in [train_path, val_path, test_path]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    print(f"✓ Found: {path}")

# Use reduced max_length to save memory
MAX_LENGTH = 512  # Reduced from 512

train_dataset = FASTADataset(train_path, tokenizer, max_length=MAX_LENGTH)
val_dataset   = FASTADataset(val_path, tokenizer, max_length=MAX_LENGTH)
test_dataset  = FASTADataset(test_path, tokenizer, max_length=MAX_LENGTH)

print("\n" + "="*50)
print(f"Dataset sizes: {len(train_dataset)}, {len(val_dataset)}, {len(test_dataset)}")
print(f"Max sequence length: {MAX_LENGTH}")
print("="*50)

# Test dataset
print("\nTesting dataset[0]...")
sample = train_dataset[0]
print("✓ Sample retrieved successfully")
print(f"  Keys: {sample.keys()}")
print(f"  input_ids shape: {sample['input_ids'].shape}")
print(f"  First 20 tokens: {sample['input_ids'][:20].tolist()}")

✓ Found: /home/mluser/AFML_RISHABH/Project/10k sequences/kinases_cluster_train_10k.fasta
✓ Found: /home/mluser/AFML_RISHABH/Project/10k sequences/kinases_cluster_val_10k.fasta
✓ Found: /home/mluser/AFML_RISHABH/Project/10k sequences/kinases_cluster_test_10k.fasta
Loading sequences from /home/mluser/AFML_RISHABH/Project/10k sequences/kinases_cluster_train_10k.fasta...
Loaded 7989 sequences
Loading sequences from /home/mluser/AFML_RISHABH/Project/10k sequences/kinases_cluster_val_10k.fasta...
Loaded 1002 sequences
Loading sequences from /home/mluser/AFML_RISHABH/Project/10k sequences/kinases_cluster_test_10k.fasta...
Loaded 1009 sequences

Dataset sizes: 7989, 1002, 1009
Max sequence length: 512

Testing dataset[0]...
✓ Sample retrieved successfully
  Keys: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
  input_ids shape: torch.Size([512])
  First 20 tokens: [4, 94, 123, 100, 153, 266, 101, 66, 129, 66, 66, 66, 66, 66, 66, 66, 101, 70, 101, 66]


In [8]:
from transformers.models.fnet.modeling_fnet import FNetBasicFourierTransform

model = AutoModelForMaskedLM.from_pretrained("google/fnet-base")

# Patch FNet Fourier Transform for float32
class FNetSafeFourierTransform(FNetBasicFourierTransform):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, hidden_states):
        hidden_states = hidden_states.to(torch.float32)
        outputs = torch.fft.fftn(hidden_states, dim=(-2, -1)).real
        return (outputs,)

model.fourier_transform = FNetSafeFourierTransform(model.config)

# Force FP32 and handle unexpected kwargs
def force_fp32_forward(original_forward):
    sig = inspect.signature(original_forward)
    
    def wrapped_forward(*args, **kwargs):
        valid_params = set(sig.parameters.keys())
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
        
        with torch.autocast(device_type='cuda', enabled=False):
            return original_forward(*args, **filtered_kwargs)
    
    return wrapped_forward

model.forward = force_fp32_forward(model.forward)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).to(torch.float32)
torch.set_float32_matmul_precision("high")

# Enable gradient checkpointing to save memory
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("✓ Gradient checkpointing enabled")

print(f"✅ Model ready on {device}")

# Check memory usage
if torch.cuda.is_available():
    print(f"   GPU memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"   GPU memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

✓ Gradient checkpointing enabled
✅ Model ready on cuda
   GPU memory allocated: 0.33 GB
   GPU memory reserved: 0.37 GB


  _C._set_float32_matmul_precision(precision)


In [9]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

# Test the data collator
print("Testing data collator...")
batch = [train_dataset[i] for i in range(2)]
collated = data_collator(batch)
print(f"✓ Collated batch keys: {collated.keys()}")
print(f"  input_ids shape: {collated['input_ids'].shape}")
print(f"  labels shape: {collated['labels'].shape}")

Testing data collator...
✓ Collated batch keys: KeysView({'input_ids': tensor([[  4,  94,   6,  ..., 101, 101,   5],
        [  4,  94, 123,  ..., 101, 164,   5]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[-100, -100,  123,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100]])})
  input_ids shape: torch.Size([2, 512])
  labels shape: torch.Size([2, 512])


In [10]:
training_args = TrainingArguments(
    output_dir="./KinaseFNet_10k_512",
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=1,  # Batch size 1
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # Keep effective batch size = 8
    save_steps=5000,
    save_total_limit=1,
    logging_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,
    fp16=False,
    bf16=False,
    eval_strategy="no",  # No evaluation during training
    dataloader_num_workers=0,
    remove_unused_columns=False,
    gradient_checkpointing=True,
    optim="adamw_torch",
    max_grad_norm=1.0,
    logging_first_step=True,
)

In [11]:
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    mask = labels != -100
    preds = np.argmax(logits, axis=-1)
    preds = preds[mask]
    labels = labels[mask]
    return accuracy_metric.compute(predictions=preds, references=labels)

print("✅ Metrics configured")

✅ Metrics configured


In [12]:
from contextlib import nullcontext

# Clear any stale accelerator state
try:
    from accelerate.state import AcceleratorState
    if hasattr(AcceleratorState, '_shared_state') and AcceleratorState._shared_state:
        AcceleratorState._reset_state()
        print("✓ Cleared accelerator state")
except Exception as e:
    print(f"Note: Could not clear accelerator state: {e}")

# Clear GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("✓ Cleared GPU cache")

# Create trainer WITHOUT eval_dataset to avoid evaluation during training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # eval_dataset removed - we'll evaluate manually later
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Disable autocast
trainer.autocast_smart_context_manager = nullcontext

print("✅ Trainer created successfully (no evaluation during training)")

# Quick test
print("\nTesting trainer dataloader...")
try:
    train_dataloader = trainer.get_train_dataloader()
    test_batch = next(iter(train_dataloader))
    print(f"✓ Dataloader test passed")
    print(f"  Batch input_ids shape: {test_batch['input_ids'].shape}")
    
    if torch.cuda.is_available():
        print(f"  GPU memory: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
except Exception as e:
    print(f"✗ Dataloader test failed: {e}")
    import traceback
    traceback.print_exc()

✓ Cleared GPU cache
✅ Trainer created successfully (no evaluation during training)

Testing trainer dataloader...
✓ Dataloader test passed
  Batch input_ids shape: torch.Size([1, 512])
  GPU memory: 0.33 GB


  trainer = Trainer(


In [13]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()


In [14]:
print("Starting training...")
print("="*60)
print("Memory optimization settings:")
print(f"  - Batch size: 2")
print(f"  - Gradient accumulation: 4 steps")
print(f"  - Max sequence length: 256")
print(f"  - Gradient checkpointing: Enabled")
print("="*60)

try:
    trainer.train()
    print("="*60)
    print("✅ Training completed!")
except RuntimeError as e:
    if "out of memory" in str(e):
        print("\n❌ CUDA Out of Memory Error!")
        print("Try further reducing:")
        print("  1. per_device_train_batch_size to 1")
        print("  2. max_length to 128")
        print("  3. Or use a smaller model")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    raise

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': None, 'bos_token_id': None}.


Starting training...
Memory optimization settings:
  - Batch size: 2
  - Gradient accumulation: 4 steps
  - Max sequence length: 256
  - Gradient checkpointing: Enabled


Step,Training Loss
1,31.1124
100,21.1186
200,20.6634
300,20.5576
400,20.3531
500,20.1862
600,19.9671
700,19.8778
800,19.6199
900,19.545


✅ Training completed!


In [18]:
print("\n" + "="*60)
print("SAVING MODEL (BEFORE EVALUATION)")
print("="*60)

# Save the trained model immediately
output_dir = "./KinaseFNet_10k_512"

print(f"\nSaving model to {output_dir}...")
trainer.save_model(output_dir)
print("✓ Model saved")

print(f"\nSaving tokenizer to {output_dir}...")
tokenizer.save_pretrained(output_dir)
print("✓ Tokenizer saved")

print(f"\nSaving training arguments...")
torch.save(training_args, os.path.join(output_dir, "training_args.bin"))
print("✓ Training args saved")

print("\n" + "="*60)
print(f"✅ MODEL SAFELY SAVED TO: {output_dir}")
print("="*60)

# Verify files were saved
import os
saved_files = os.listdir(output_dir)
print(f"\nSaved files ({len(saved_files)}):")
for f in sorted(saved_files)[:10]:  # Show first 10 files
    print(f"  - {f}")
if len(saved_files) > 10:
    print(f"  ... and {len(saved_files) - 10} more files")

print("\n✅ Your model is now safely saved!")
print("   You can load it later with:")
print(f'   model = AutoModelForMaskedLM.from_pretrained("{output_dir}")')
print(f'   tokenizer = AutoTokenizer.from_pretrained("{output_dir}")')


SAVING MODEL (BEFORE EVALUATION)

Saving model to ./KinaseFNet_10k_512...
✓ Model saved

Saving tokenizer to ./KinaseFNet_10k_512...
✓ Tokenizer saved

Saving training arguments...
✓ Training args saved

✅ MODEL SAFELY SAVED TO: ./KinaseFNet_10k_512

Saved files (7):
  - checkpoint-19980
  - config.json
  - model.safetensors
  - special_tokens_map.json
  - tokenizer.json
  - tokenizer_config.json
  - training_args.bin

✅ Your model is now safely saved!
   You can load it later with:
   model = AutoModelForMaskedLM.from_pretrained("./KinaseFNet_10k_512")
   tokenizer = AutoTokenizer.from_pretrained("./KinaseFNet_10k_512")


In [19]:
# %%
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling
import math
from tqdm import tqdm
import inspect

# ============================================================
# LOAD MODEL AND TOKENIZER
# ============================================================
output_dir = "./KinaseFNet_10k_512"

print(f"Loading trained model from {output_dir}...")
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = AutoModelForMaskedLM.from_pretrained(output_dir)
print("✓ Model and tokenizer loaded successfully")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

# ============================================================
# PATCH MODEL FOR SAFETY
# ============================================================
def safe_forward(original_forward):
    sig = inspect.signature(original_forward)
    valid_keys = set(sig.parameters.keys())
    def wrapped_forward(*args, **kwargs):
        filtered = {k: v for k, v in kwargs.items() if k in valid_keys}
        return original_forward(*args, **filtered)
    return wrapped_forward

model.forward = safe_forward(model.forward)
print("✓ Model forward patched to ignore unsupported arguments")

# ============================================================
# DATA COLLATOR (for masking during evaluation)
# ============================================================
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

test_loader = DataLoader(test_dataset, batch_size=4, collate_fn=data_collator)

print("\n============================================================")
print("EVALUATING MASKED LANGUAGE MODEL (with dynamic masking)")
print("============================================================")

total_loss = 0.0
total_count = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        if loss is not None:
            total_loss += loss.item() * batch["input_ids"].size(0)
            total_count += batch["input_ids"].size(0)

# ============================================================
# RESULTS
# ============================================================
if total_count > 0:
    avg_loss = total_loss / total_count
    perplexity = math.exp(avg_loss)
    print("\n✅ Evaluation Complete")
    print("------------------------------------------------------------")
    print(f"Average Test Loss : {avg_loss:.4f}")
    print(f"Perplexity         : {perplexity:.2f}")
    print("------------------------------------------------------------")
else:
    print("⚠️ No valid batches returned loss — check masking or dataset.")


Loading trained model from ./KinaseFNet_10k_512...
✓ Model and tokenizer loaded successfully
✓ Model forward patched to ignore unsupported arguments

EVALUATING MASKED LANGUAGE MODEL (with dynamic masking)


Evaluating: 100%|██████████| 253/253 [00:03<00:00, 71.09it/s]


✅ Evaluation Complete
------------------------------------------------------------
Average Test Loss : 2.2356
Perplexity         : 9.35
------------------------------------------------------------





In [20]:
# %%
import torch
from torch.utils.data import ConcatDataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling
import math
from tqdm import tqdm
import inspect

# ============================================================
# LOAD MODEL AND TOKENIZER
# ============================================================
output_dir = "./KinaseFNet_10k_512"

print(f"Loading trained model from {output_dir}...")
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = AutoModelForMaskedLM.from_pretrained(output_dir)
print("✓ Model and tokenizer loaded successfully")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

# ============================================================
# PATCH MODEL FOR SAFETY (ignore unsupported kwargs)
# ============================================================
def safe_forward(original_forward):
    sig = inspect.signature(original_forward)
    valid_keys = set(sig.parameters.keys())
    def wrapped_forward(*args, **kwargs):
        filtered = {k: v for k, v in kwargs.items() if k in valid_keys}
        return original_forward(*args, **filtered)
    return wrapped_forward

model.forward = safe_forward(model.forward)
print("✓ Model forward patched to ignore unsupported arguments")

# ============================================================
# MERGE DATASETS (train + validation + test)
# ============================================================
full_dataset = ConcatDataset([train_dataset, val_dataset, test_dataset])
print(f"\nTotal samples in full dataset: {len(full_dataset):,}")

# ============================================================
# DATA COLLATOR (dynamic masking)
# ============================================================
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

# Slightly larger batch for faster evaluation if GPU allows
test_loader = DataLoader(full_dataset, batch_size=8, collate_fn=data_collator)

print("\n============================================================")
print("EVALUATING MASKED LANGUAGE MODEL ON FULL DATASET")
print("============================================================")

total_loss = 0.0
total_count = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        if loss is not None:
            total_loss += loss.item() * batch["input_ids"].size(0)
            total_count += batch["input_ids"].size(0)

# ============================================================
# RESULTS
# ============================================================
if total_count > 0:
    avg_loss = total_loss / total_count
    perplexity = math.exp(avg_loss)
    print("\n✅ Evaluation Complete")
    print("------------------------------------------------------------")
    print(f"Average Full Dataset Loss : {avg_loss:.4f}")
    print(f"Perplexity                : {perplexity:.2f}")
    print("------------------------------------------------------------")
else:
    print("⚠️ No valid batches returned loss — check dataset or collator.")


Loading trained model from ./KinaseFNet_10k_512...
✓ Model and tokenizer loaded successfully
✓ Model forward patched to ignore unsupported arguments

Total samples in full dataset: 10,000

EVALUATING MASKED LANGUAGE MODEL ON FULL DATASET


Evaluating: 100%|██████████| 1250/1250 [00:36<00:00, 33.79it/s]


✅ Evaluation Complete
------------------------------------------------------------
Average Full Dataset Loss : 1.9496
Perplexity                : 7.03
------------------------------------------------------------



