In [1]:
import matplotlib.pyplot as plt
import os
import torch 
from torch.utils.data import DataLoader
import json
import sys
import torch

# --- Project layout ---------------------------------------------------------
# Every path in this cell is derived from PROJECT_ROOT, so pointing the notebook
# at another checkout (or another experiment) means editing a single line.
PROJECT_ROOT = "/home/george-vengrovski/Documents/projects/tweety_bert_paper"
EXPERIMENT_DIR = os.path.join(
    PROJECT_ROOT, "experiments", "TweetyBERT-cluster_LLB3_10_Mask_50_clusters"
)

# Keep the working directory at the project root, as the original cell did.
os.chdir(PROJECT_ROOT)

# Register the source directory by absolute path so the imports below do not
# depend on whatever the working directory happens to be, and skip the append
# when the entry is already present so re-running this cell does not pile up
# duplicate sys.path entries.
src_dir = os.path.join(PROJECT_ROOT, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

from data_class import CollateFunction
from utils import load_model

# Checkpoint and configuration of the pretrained model used as the backbone.
weights_path = os.path.join(EXPERIMENT_DIR, "saved_weights", "model_step_8600.pth")
config_path = os.path.join(EXPERIMENT_DIR, "config.json")

tweety_bert_model = load_model(config_path, weights_path)

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Data Class

In [3]:
from torch.utils.data import DataLoader
from data_class import SongDataSet_Image

# Directories holding the training and test splits.
train_dir = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/llb3_train_50"
test_dir = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/llb3_test_50"

# Both datasets are built with identical settings; the train split is
# constructed first, then the test split.
train_dataset, test_dataset = (
    SongDataSet_Image(split_dir, num_classes=21, psuedo_labels_generated=False)
    for split_dir in (train_dir, test_dir)
)

# Adjust the segment length if needed
collate_fn = CollateFunction(segment_length=1000)

# One loader per split, sharing batch size, shuffling and the collate function.
train_loader, test_loader = (
    DataLoader(split_dataset, batch_size=48, shuffle=True, collate_fn=collate_fn)
    for split_dataset in (train_dataset, test_dataset)
)

## Define Linear Classifier and Train

In [4]:
from linear_probe import LinearProbeModel, LinearProbeTrainer

# Probe built on top of the pretrained model loaded earlier in the notebook.
# NOTE(review): freeze_layers / layer_num / layer_id presumably choose which
# backbone activations feed the classifier and keep the backbone weights
# fixed -- confirm against linear_probe.LinearProbeModel.
classifier_model = LinearProbeModel(
    model=tweety_bert_model,
    model_type="neural_net",
    layer_id="attention_output",
    layer_num=-3,
    freeze_layers=True,
    classifier_dims=196,
    num_classes=21,
)

# Place the probe on the device selected in the setup cell.
classifier_model = classifier_model.to(device)

In [5]:
# Configure training of the probe. With batches_per_eval=1 the recorded output
# of this cell shows one FER / train-loss / val-loss line per batch.
# NOTE(review): patience presumably controls early stopping -- confirm in
# linear_probe.LinearProbeTrainer.
trainer = LinearProbeTrainer(
    model=classifier_model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    lr=1e-3,
    plotting=True,
    batches_per_eval=1,
    desired_total_batches=1e4,
    patience=4,
)

trainer.train()

  from .autonotebook import tqdm as notebook_tqdm


Batch 1: FER = 57.14%, Train Loss = 3.1360, Val Loss = 2.3618
Batch 2: FER = 34.88%, Train Loss = 2.2790, Val Loss = 1.7033
Batch 3: FER = 30.90%, Train Loss = 1.6679, Val Loss = 1.3300
Batch 4: FER = 24.63%, Train Loss = 1.3690, Val Loss = 1.0885
Batch 5: FER = 23.03%, Train Loss = 1.2085, Val Loss = 0.9812
Batch 6: FER = 21.49%, Train Loss = 0.9500, Val Loss = 0.8732
Batch 7: FER = 20.13%, Train Loss = 0.9660, Val Loss = 0.7711
Batch 8: FER = 19.38%, Train Loss = 0.8020, Val Loss = 0.7279
Batch 9: FER = 16.43%, Train Loss = 0.7715, Val Loss = 0.6492
Batch 10: FER = 17.02%, Train Loss = 0.6638, Val Loss = 0.6344
Batch 11: FER = 16.51%, Train Loss = 0.7864, Val Loss = 0.6087
Batch 12: FER = 16.79%, Train Loss = 0.6356, Val Loss = 0.6527
Batch 13: FER = 15.67%, Train Loss = 0.5061, Val Loss = 0.6090
Batch 14: FER = 14.13%, Train Loss = 0.6564, Val Loss = 0.5546
Batch 15: FER = 14.41%, Train Loss = 0.5017, Val Loss = 0.5660
Batch 16: FER = 13.88%, Train Loss = 0.4797, Val Loss = 0.5186
B

KeyboardInterrupt: 

## Analyze

In [None]:
from linear_probe import ModelEvaluator

# Evaluate the trained probe on the test loader.
evaluator = ModelEvaluator(classifier_model, test_loader)

# A single pass, capped at 1250 batches; returns the per-class and the overall
# frame error rates.
class_frame_error_rates, total_frame_error_rate = (
    evaluator.validate_model_multiple_passes(num_passes=1, max_batches=1250)
)

# Persist both results under the project's results directory.
evaluator.save_results(
    class_frame_error_rates,
    total_frame_error_rate,
    '/home/george-vengrovski/Documents/projects/tweety_bert_paper/results/test',
)