In [9]:
import matplotlib.pyplot as plt
import os
import torch 
from torch.utils.data import DataLoader
import json
import sys
import torch

# Single source of truth for the project checkout; every path in this cell is
# derived from it so relocating the project means editing one line.
PROJECT_ROOT = '/home/george-vengrovski/Documents/projects/tweety_bert_paper'
EXPERIMENT_DIR = os.path.join(PROJECT_ROOT, "experiments", "TweetyBERT-MSE_LLB3_10_Mask")

# Work from the project root, as the original cell did.
os.chdir(PROJECT_ROOT)

# Register src/ by ABSOLUTE path: a relative "src" entry only resolves while the
# working directory happens to be the project root. The membership check keeps
# sys.path free of duplicates when this cell is re-run.
src_dir = os.path.join(PROJECT_ROOT, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Project-local modules; these imports must come after the sys.path setup above.
from data_class import CollateFunction
from utils import load_model

# Checkpoint and config of the experiment being probed.
weights_path = os.path.join(EXPERIMENT_DIR, "saved_weights", "model_step_15800.pth")
config_path = os.path.join(EXPERIMENT_DIR, "config.json")

tweety_bert_model = load_model(config_path, weights_path)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Datasets and Data Loaders

In [10]:
from torch.utils.data import DataLoader
from data_class import SongDataSet_Image

# Directories holding the llb3 train / test files.
train_dir = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/llb3_train"
test_dir = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/llb3_test"

# Both splits are built with the same dataset options.
# (`psuedo_labels_generated` is the keyword exactly as the dataset class is called here.)
dataset_options = dict(num_classes=196, psuedo_labels_generated=True)
train_dataset = SongDataSet_Image(train_dir, **dataset_options)
test_dataset = SongDataSet_Image(test_dir, **dataset_options)

# One collate function instance is shared by both loaders.
# Adjust the segment length if needed.
collate_fn = CollateFunction(segment_length=1000)

# Identical loader settings for train and test: batches of 48, shuffled.
loader_options = dict(batch_size=48, shuffle=True, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)

## Define Linear Classifier and Train

In [11]:
from linear_probe import LinearProbeModel, LinearProbeTrainer

# Build the probe around the loaded TweetyBERT model (layers frozen, reading
# "attention_output" at layer index -3) and place it on the selected device.
classifier_model = LinearProbeModel(
    num_classes=196,
    model_type="neural_net",
    model=tweety_bert_model,
    freeze_layers=True,
    layer_num=-3,
    layer_id="attention_output",
    classifier_dims=196,
).to(device)

In [12]:
# Training configuration for the linear probe, gathered in one place.
# NOTE(review): desired_total_batches is passed as the float 1e4 -- confirm the
# trainer accepts a non-integer batch count.
trainer_options = dict(
    model=classifier_model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    lr=1e-3,
    plotting=True,
    batches_per_eval=1,
    desired_total_batches=1e4,
    patience=4,
)

trainer = LinearProbeTrainer(**trainer_options)
trainer.train()

Batch 1: FER = 84.17%, Train Loss = 5.4033, Val Loss = 4.4738
Batch 2: FER = 38.31%, Train Loss = 4.5192, Val Loss = 3.6220
Batch 3: FER = 27.66%, Train Loss = 3.6444, Val Loss = 2.8906
Batch 4: FER = 20.74%, Train Loss = 2.8567, Val Loss = 2.2174
Batch 5: FER = 16.59%, Train Loss = 2.1952, Val Loss = 1.6526
Batch 6: FER = 17.49%, Train Loss = 1.7201, Val Loss = 1.3121
Batch 7: FER = 16.64%, Train Loss = 1.4019, Val Loss = 1.0467
Batch 8: FER = 16.40%, Train Loss = 1.0639, Val Loss = 0.8319
Batch 9: FER = 17.60%, Train Loss = 0.8741, Val Loss = 0.7500
Batch 10: FER = 16.65%, Train Loss = 0.7710, Val Loss = 0.6407
Batch 11: FER = 16.95%, Train Loss = 0.6789, Val Loss = 0.6031
Batch 12: FER = 14.56%, Train Loss = 0.5562, Val Loss = 0.5396
Batch 13: FER = 16.54%, Train Loss = 0.5926, Val Loss = 0.5947
Batch 14: FER = 14.25%, Train Loss = 0.4939, Val Loss = 0.5050
Batch 15: FER = 12.69%, Train Loss = 0.4966, Val Loss = 0.4415
Batch 16: FER = 12.09%, Train Loss = 0.5628, Val Loss = 0.4276
B

KeyboardInterrupt: 

## Analyze

In [None]:
from linear_probe import ModelEvaluator

# Destination passed to save_results for the evaluation output.
results_dir = '/home/george-vengrovski/Documents/projects/tweety_bert_paper/results/test'

# Evaluate the trained probe on the test loader: a single pass, capped at 1250 batches.
evaluator = ModelEvaluator(classifier_model, test_loader)
class_frame_error_rates, total_frame_error_rate = evaluator.validate_model_multiple_passes(
    num_passes=1,
    max_batches=1250,
)

# Persist the per-class and overall frame error rates.
evaluator.save_results(
    class_frame_error_rates,
    total_frame_error_rate,
    results_dir,
)