In [1]:
# Notebook setup: imports, project path configuration, and loading of the
# pretrained TweetyBERT checkpoint used by the cells below.

# --- Standard library ---
import json
import os
import sys

# --- Third-party ---
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

# Root of the project checkout. Every path in this cell is derived from this
# single constant, so pointing the notebook at another checkout means editing
# one line instead of three.
PROJECT_ROOT = '/home/george-vengrovski/Documents/projects/tweety_bert_paper'

# Register the project's `src` directory using an absolute path, so the entry
# stays valid regardless of the working directory at import time.
sys.path.append(os.path.join(PROJECT_ROOT, "src"))

# Work from the project root; relative paths used later in the notebook
# (e.g. 'results/') are resolved against this directory.
os.chdir(PROJECT_ROOT)

# --- Project-local ---
# These are expected to resolve through the `src` entry added above
# (presumably src/data_class.py and src/utils.py -- confirm against the repo).
from data_class import CollateFunction
from utils import load_model

# Experiment whose checkpoint and config are loaded.
EXPERIMENT_DIR = os.path.join(
    PROJECT_ROOT, "experiments", "TweetyBERT-MSE-Mask-Before-50-mask-alpha-1"
)
weights_path = os.path.join(EXPERIMENT_DIR, "saved_weights", "model_step_6400.pth")
config_path = os.path.join(EXPERIMENT_DIR, "config.json")

# Build the model from the experiment config and the saved weights.
tweety_bert_model = load_model(config_path, weights_path)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Datasets and Data Loaders

In [2]:
from torch.utils.data import DataLoader
from data_class import SongDataSet_Image

# Locations of the llb3 train and test folders.
train_dir = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/llb3_train"
test_dir = "/home/george-vengrovski/Documents/projects/tweety_bert_paper/files/llb3_test"

# One dataset per split; both are built with the same number of classes.
train_dataset = SongDataSet_Image(
    train_dir,
    num_classes=196,
)
test_dataset = SongDataSet_Image(
    test_dir,
    num_classes=196,
)

# Shared collate function; change segment_length here if a different
# segment size is needed.
collate_fn = CollateFunction(segment_length=1000)

# Both loaders use the same batch size and collate function, and both shuffle.
train_loader = DataLoader(
    train_dataset,
    batch_size=48,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=48,
    shuffle=True,
    collate_fn=collate_fn,
)

## Define Linear Classifier and Train

In [3]:
from linear_probe import LinearProbeModel, LinearProbeTrainer

# Wrap the pretrained TweetyBERT model in a classifier head and move the
# result onto the selected device.
# NOTE(review): freeze_layers=False presumably leaves the backbone trainable,
# i.e. this fine-tunes rather than probes -- confirm this is intended.
classifier_model = LinearProbeModel(
    num_classes=196,
    model_type="neural_net",
    model=tweety_bert_model,
    freeze_layers=False,
    layer_num=2,
    layer_id="attention_output",
    classifier_dims=196,
).to(device)

In [4]:
# Configure the training run for the classifier and start it.
# With batches_per_eval=10 the recorded log reports FER / train loss /
# validation loss every 10 batches.
# NOTE(review): desired_total_batches is the float literal 1e4 -- confirm the
# trainer does not require an int. `patience` is presumably an early-stopping
# setting; verify against LinearProbeTrainer.
trainer = LinearProbeTrainer(
    model=classifier_model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    lr=3e-3,
    plotting=True,
    batches_per_eval=10,
    desired_total_batches=1e4,
    patience=4,
)
trainer.train()

  from .autonotebook import tqdm as notebook_tqdm


Batch 10: FER = 48.55%, Train Loss = 13.7251, Val Loss = 13.1249
Batch 20: FER = 37.28%, Train Loss = 6.1086, Val Loss = 4.4253
Batch 30: FER = 31.16%, Train Loss = 3.8919, Val Loss = 3.6376
Batch 40: FER = 27.34%, Train Loss = 1.4030, Val Loss = 1.1725
Batch 50: FER = 25.84%, Train Loss = 1.9968, Val Loss = 1.9372
Batch 60: FER = 23.21%, Train Loss = 1.5552, Val Loss = 1.4037
Batch 70: FER = 18.09%, Train Loss = 0.9616, Val Loss = 0.7399
Batch 80: FER = 15.82%, Train Loss = 0.7849, Val Loss = 0.5544
Batch 90: FER = 12.12%, Train Loss = 0.3294, Val Loss = 0.4133
Batch 100: FER = 9.32%, Train Loss = 0.3885, Val Loss = 0.3309
Batch 110: FER = 8.30%, Train Loss = 0.3475, Val Loss = 0.2998
Batch 120: FER = 7.97%, Train Loss = 0.2756, Val Loss = 0.2796
Batch 130: FER = 6.96%, Train Loss = 0.2989, Val Loss = 0.2368
Batch 140: FER = 6.83%, Train Loss = 0.2104, Val Loss = 0.2203
Batch 150: FER = 6.58%, Train Loss = 0.1948, Val Loss = 0.2165
Batch 160: FER = 6.27%, Train Loss = 0.2254, Val Loss

KeyboardInterrupt: 

## Analyze

In [5]:
from linear_probe import ModelEvaluator

# Evaluate the trained classifier on the test loader and save the results.
evaluator = ModelEvaluator(classifier_model, test_loader)

# Single evaluation pass. max_batches=1250 is presumably only an upper bound:
# the recorded run below processed 12 batches.
(
    class_frame_error_rates,
    total_frame_error_rate,
) = evaluator.validate_model_multiple_passes(
    num_passes=1,
    max_batches=1250,
)

# 'results/' is a relative path, so it is resolved against the current
# working directory.
evaluator.save_results(
    class_frame_error_rates,
    total_frame_error_rate,
    'results/',
)

Evaluating: 100%|██████████| 12/12 [00:04<00:00,  2.62batch/s]
