In [29]:
# !pip install torch==2.2.2 torchtext==0.17.2 nltk
# !pip install torchdata==0.7.1
# !pip install pyarrow

In [30]:
import torch
import os
import pickle
from model import ClassificationNet, train_epoch, evaluate_epoch, save_list_to_file, load_list_from_file

In [16]:
from dataloader import get_dataloaders, DEVICE

# Confirm the project import worked and report the device exported by `dataloader`.
print(
    "Imported successfully!",
    f"Running on device: {DEVICE}",
    sep="\n",
)

# Build the three data splits and the vocabulary using one shared batch size.
BATCH_SIZE = 64
train_dataloader, valid_dataloader, test_dataloader, vocab = get_dataloaders(
    batch_size=BATCH_SIZE
)

# --- Verification Step ---
# Pull a single batch to sanity-check what the training loader yields.
print("\nVerifying by fetching one batch from train_dataloader...")
labels, texts = next(iter(train_dataloader))

print(
    f"Labels batch shape: {labels.shape}",
    f"Texts batch shape: {texts.shape}",
    sep="\n",
)

Imported successfully!
Running on device: cuda
Loading data from Parquet files...
Vocabulary Size: 95811

Verifying by fetching one batch from train_dataloader...
Labels batch shape: torch.Size([64])
Texts batch shape: torch.Size([64, 74])


In [17]:
# Sizes handed to ClassificationNet: vocabulary length and number of target classes.
vocab_size, num_classes = len(vocab), 4

In [18]:
# Every artifact of this run lives under one directory; metrics get a subfolder.
EXPERIMENT_DIR = "runs/adam_from_epoch51"
METRICS_DIR = os.path.join(EXPERIMENT_DIR, "metrics")
NEW_CHECKPOINT_PATH = os.path.join(EXPERIMENT_DIR, 'checkpoint.pth')

# Creating the nested metrics folder also creates EXPERIMENT_DIR as its parent,
# and exist_ok=True makes the cell safe to re-run.
os.makedirs(METRICS_DIR, exist_ok=True)

print(
    f"Experiment artifacts will be saved in: {EXPERIMENT_DIR}",
    f"Metrics will be saved in: {METRICS_DIR}",
    sep="\n",
)

Experiment artifacts will be saved in: runs/adam_from_epoch51
Metrics will be saved in: runs/adam_from_epoch51/metrics


In [19]:
# Set up the loss, the network, its optimizer and the learning-rate schedule.
criterion = torch.nn.CrossEntropyLoss()

model = ClassificationNet(vocab_size=vocab_size, num_class=num_classes)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# StepLR multiplies the learning rate by `gamma` (a 10x reduction) after every
# `step_size` (8) calls to scheduler.step(); it has no effect unless it is stepped.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=8,
    gamma=0.1,
)

In [20]:
# Defaults for a run that starts from scratch; a checkpoint may override them later.
start_epoch = 1
best_val_accuracy = 0.0

# One independent, initially empty list per tracked metric.
history = {
    metric: []
    for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc')
}

In [21]:
# Checkpoint written by an earlier run. When this file exists, the next cell
# reads its model weights, epoch counter, metric history and best validation
# accuracy to seed the current experiment.
BASE_CHECKPOINT_PATH = 'runs/initial/checkpoint_epoch_50.pth'

In [22]:
# Warm-start this run from the base checkpoint when it is available.
# Only the model weights and the bookkeeping values (epoch counter, metric
# history, best validation accuracy) are restored; the optimizer created
# earlier is left freshly initialised.
if os.path.exists(BASE_CHECKPOINT_PATH):
    print(f"Found base checkpoint. Initializing model with weights from '{BASE_CHECKPOINT_PATH}'...")
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
    # on a GPU still loads on a CPU-only machine (and vice versa).
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(BASE_CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])  # Load weights only
    start_epoch = checkpoint['epoch'] + 1                  # Resume after the last finished epoch
    history = checkpoint['history']                        # Continue history
    best_val_accuracy = checkpoint['best_val_accuracy']
else:
    print("No checkpoints found. Starting a completely new training run.")

Found base checkpoint. Initializing model with weights from 'runs/initial/checkpoint_epoch_50.pth'...


In [32]:
EPOCHS = 70            # Train up to a total of 70 epochs

print(f"Starting training from epoch {start_epoch}...")

# Each iteration runs one training pass and one validation pass, records the
# four metrics in `history`, saves the weights whenever validation accuracy
# beats the best value seen so far, and then advances the LR schedule.
for epoch in range(start_epoch, EPOCHS + 1):
    print("-" * 50)
    print(f"Epoch {epoch}/{EPOCHS}")

    # Run training and validation
    train_loss, train_acc = train_epoch(model, train_dataloader, criterion, optimizer, DEVICE)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    val_loss, val_acc = evaluate_epoch(model, valid_dataloader, criterion, DEVICE)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print("\nEpoch Summary:")
    print(f"\tTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"\tValid Loss: {val_loss:.4f} | Valid Acc: {val_acc*100:.2f}%")

    # Save best model to the experiment folder
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        best_model_path = os.path.join(EXPERIMENT_DIR, 'best_model.pth')
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved to '{best_model_path}'")

    # Advance the learning-rate schedule once per epoch, after this epoch's
    # optimizer updates. Without this call the StepLR scheduler never fires and
    # the learning rate stays at its initial value for the whole run.
    # The scheduler counts epochs of THIS run, not the global epoch number.
    scheduler.step()

print("-" * 50)
print("Training Finished!")


# Save final history to the experiment folder (one pickle file per metric)
for key, value in history.items():
    file_path = os.path.join(METRICS_DIR, f"{key}.pkl")
    save_list_to_file(value, file_path)

Data successfully saved to: runs/adam_from_epoch51/metrics/train_loss.pkl
Data successfully saved to: runs/adam_from_epoch51/metrics/train_acc.pkl
Data successfully saved to: runs/adam_from_epoch51/metrics/val_loss.pkl
Data successfully saved to: runs/adam_from_epoch51/metrics/val_acc.pkl


In [33]:
# Save the checkpoint for this run
# `epoch` is the loop variable left behind by the training cell, i.e. the last
# epoch that was executed. The scheduler state is stored next to the optimizer
# state so a later run can resume the learning-rate schedule where it stopped;
# loaders that only read the other keys are unaffected by the extra entry.
current_checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'history': history,
    'best_val_accuracy': best_val_accuracy
}
torch.save(current_checkpoint, NEW_CHECKPOINT_PATH)