### Prerequisites: clone the repository and install its requirements

In [None]:
! git clone https://github.com/marios1861/hosa-voice.git
%cd hosa-voice/v2
! pip install -r requirements.txt

### First, import the torchvision models (their classifier heads are replaced below to output 4 classes instead of 1000)

In [1]:
from torchvision.models import vit_b_16, ViT_B_16_Weights, mobilenet_v3_large, MobileNet_V3_Large_Weights

### Choose which of the two models to use

In [None]:
from torch import nn

# Number of target classes; the pretrained heads output 1000 ImageNet classes.
num_classes = 4

# Pretrained ViT-B/16 backbone with its classification head swapped for a fresh linear layer.
model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
head_in_features = model.heads[0].in_features
model.heads = nn.Linear(head_in_features, num_classes)

# Alternative backbone: MobileNetV3-Large (pair it with log_name = "mobilenet" in the trainer cell).
# model = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
# model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)

### Choose the training learning rate and batch size

In [None]:
# Learning rate (passed to VoiceClassifier) and batch size (passed to DatasetModule).
lr, batch_size = 1e-3, 32

### Initialize the Lightning module and data module

In [None]:
import os

# Move from the v2/ directory (entered by the setup cell) up to the repository root, which the
# relative paths used below ("v2" on sys.path, the dataset directory) are written against.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if os.path.basename(os.getcwd()) == "v2":
    os.chdir('..')

In [None]:
import os
import sys

# Make the project's v2/ modules (voice_classifier, voice_datasets) importable.
# Store an absolute path: a relative "v2" entry is resolved against the *current* working
# directory on every import, so it silently breaks if the notebook changes directory later.
# The membership check keeps the cell idempotent when it is re-run.
v2_dir = os.path.abspath("v2")
if v2_dir not in sys.path:
    sys.path.append(v2_dir)

In [None]:
import os

from voice_classifier import VoiceClassifier
from voice_datasets import DatasetModule

# Wrap the backbone selected above in the project's LightningModule.
pl_module = VoiceClassifier(model, lr)

# Resolve the dataset path against the current working directory *now*, so the data module
# keeps pointing at the right place even if the cwd changes before fit() runs.
# NOTE(review): presumably DatasetModule reads data_dir lazily in setup() — TODO confirm.
data_dir = os.path.abspath('datasets/bicepstrum_image/bicepstrum_ml_normalized_imagesc_100_100')
pl_data = DatasetModule(batch_size, data_dir=data_dir, num_workers=2)

### Initialize the trainer

In [None]:
import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch.callbacks import RichProgressBar

# Trade some float32 matmul precision for speed on hardware that supports it.
torch.set_float32_matmul_precision("medium")

SEED = 42
epochs = 8
log_name = "transformer"
# log_name = "mobilenet"  # use together with the MobileNet backbone

# Seed Python's `random`, NumPy and torch (and dataloader workers) so data shuffling and other
# randomness during training are reproducible. Note: the replacement classifier head was
# initialised when the model was built, before this seed is applied.
seed_everything(SEED, workers=True)

trainer = Trainer(
    # precision="16-mixed",
    # gradient_clip_algorithm="norm",  # has no effect unless gradient_clip_val is also set
    max_epochs=epochs,
    log_every_n_steps=1,
    callbacks=[RichProgressBar()],
)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
%tensorboard --logdir logs/tb_logs

In [None]:
import os

# Stay in the repository root rather than calling os.chdir("..") a second time: the "logs"
# directory TensorBoard watches, the relative "v2" entry on sys.path and the relative dataset
# path are all resolved against the current working directory, and leaving the repository
# sends training logs outside TensorBoard's view.
assert os.path.isdir("v2"), f"expected the hosa-voice repository root as cwd, got {os.getcwd()}"

In [None]:
# Train with TensorBoard logging; events go to logs/<log_name>/version_<n>/tb_logs.
tb_logger = TensorBoardLogger(save_dir="logs", name=log_name, sub_dir="tb_logs")
trainer.logger = tb_logger
trainer.fit(model=pl_module, datamodule=pl_data)

In [None]:
import os

# Log test metrics as CSV next to the training run. Without an explicit version, CSVLogger picks
# the next free version_<n> under logs/<log_name>, i.e. a different directory from the one the
# training logger just used. Reuse the training logger's version, with a csv_logs sub-directory
# mirroring TensorBoard's tb_logs.
train_version = trainer.logger.version
version_dir = train_version if isinstance(train_version, str) else f"version_{train_version}"
trainer.logger = CSVLogger("logs", name=log_name, version=os.path.join(version_dir, "csv_logs"))
trainer.test(pl_module, pl_data)