In [2]:
import sys; sys.path.append('..')
import torch
from torch.utils.data import random_split
import pandas as pd
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

print("In Colab: {}".format(IN_COLAB))

if IN_COLAB:
  !pip install torchmetrics
  !pip install kornia
  !pip install torchvision
  google.colab.drive.mount('/content/drive')
  %cd /content/drive/My Drive/Go-Viral-Project/notebooks


# Using CNN for feature extraction from audio input

In [5]:
AUDIO_PATH = '../data/audio'

TENSOR_PATH = '../data/specs'

METADATA_PATH = '../data/metadata.csv'

SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

<torch._C.Generator at 0x78952023a970>

In [6]:
# clean up df - remove songs that dont have spectrograms
import os

df = pd.read_csv('../data/metadata.csv')
data_path = '../data/specs/'
files_not_found = 0

for idx, row in df.iterrows():
    song_path = os.path.join(data_path, row['id'] + '.pt')

    if not os.path.exists(song_path):
        # print(f"File not found: {song_path}")
        df = df.drop(idx)
        files_not_found += 1

print(f"Number of files not found: {files_not_found}")

Number of files not found: 0


In [7]:
from torchvision import transforms

def preprocess_mbnet(X):
  prep = transforms.Compose([
      transforms.Resize((224,2206)),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

  return prep(torch.cat((X,X,X),dim=0))

In [8]:
from src.RNN_utils.dataset import SoundDS
from torch.utils.data import default_collate

myds = SoundDS(df, '../data/specs/', preprocess_mbnet)

# Random split of 80:20 between training and validation
num_items = len(myds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(myds, [num_train, num_val])

to_gpu = lambda x: list(map(lambda t: t.to(device), default_collate(x)))
# Create training and validation data loaders
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, collate_fn = to_gpu, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, collate_fn = to_gpu, shuffle=False)

In [9]:
b_size, channels, hight, width = next(iter(train_dl))[0].shape
num_batches = len(train_dl)
print(f'num batches: {num_batches}\nbatch size: {b_size}\nchannels: {channels}\nhight: {hight} \nwidth: {width}')



num batches: 183
batch size: 16
channels: 3
hight: 224 
width: 2206


## MobileNet V2:

### Overfitting the model:

In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.classifier = nn.Sequential(
    nn.Linear(1280,2),
    nn.Softmax(dim=1)
)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.1)

epochs = 20

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [None]:
print(model)

In [None]:
sum(param.numel() for param in model.parameters())

2226434

In [None]:
(X,y) = next(iter(train_dl))

In [None]:
for epoch in range(epochs):
    optimizer.zero_grad()
    y_prob = model(X)
    loss = criterion(y_prob,y)
    loss.backward()
    optimizer.step()
    loss = loss.item()
    scheduler.step()
    acc = torch.sum(torch.argmax(y_prob,dim=1)==y).item()/b_size
    print(f'Epoch #{epoch}: Loss - {loss}, Accuracy - {acc}')

### Training the model:

In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.classifier = nn.Sequential(
    nn.Linear(1280,2),
    nn.Softmax(dim=1)
)

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.2)

In [None]:
from src.RNN_utils.trainer import trainer
from src.RNN_utils.cross_val import crossValidate

train_model = trainer(model,criterion,optimizer,scheduler,device)

In [None]:
results = train_model.train(train_dl,20,True)

### Cross validation:

In [10]:
configs = []
results = []

In [11]:
from src.RNN_utils.trainer import trainer
from src.RNN_utils.cross_val import crossValidate
from kornia.losses.focal import BinaryFocalLossWithLogits

config = {'lr':1e-3, 'weight_decay':1e-4, 'step_size': 10, 'gamma': 0.1}

model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.classifier = nn.Sequential(
    nn.Linear(1280,2),
    nn.Softmax(dim=1)
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=config['lr'], weight_decay=config['weight_decay'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=config['step_size'],gamma=config['gamma'])

train_model = trainer(model,criterion,optimizer,scheduler,device)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 122MB/s]


In [12]:
cv_obj = crossValidate(train_ds=train_ds, device=device, folds=3, batch_size=16)

In [None]:
results.append(cv_obj.runCV(train_model, epochs=20))
configs.append(config)

Fold #0:


Train Batch: 100%|██████████| 122/122 [32:05<00:00, 15.78s/it]


Epoch #0: Loss - 83.9294042289257, Accuracy - 0.6132852729145211


Test Batch: 100%|██████████| 61/61 [15:26<00:00, 15.19s/it]


Val results: Loss - 42.07582414150238, Accuracy - 0.5534979423868313


Train Batch: 100%|██████████| 122/122 [01:56<00:00,  1.05it/s]


Epoch #1: Loss - 82.60667553544044, Accuracy - 0.6194644696189495


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.72it/s]


Val results: Loss - 39.33746412396431, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #2: Loss - 82.16661912202835, Accuracy - 0.6179196704428425


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.65it/s]


Val results: Loss - 40.64802020788193, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:58<00:00,  1.03it/s]


Epoch #3: Loss - 81.70745900273323, Accuracy - 0.6204943357363543


Test Batch: 100%|██████████| 61/61 [00:34<00:00,  1.77it/s]


Val results: Loss - 39.73259061574936, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:58<00:00,  1.03it/s]


Epoch #4: Loss - 81.71707057952881, Accuracy - 0.6220391349124614


Test Batch: 100%|██████████| 61/61 [00:34<00:00,  1.78it/s]


Val results: Loss - 39.46547266840935, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:58<00:00,  1.03it/s]


Epoch #5: Loss - 81.86672246456146, Accuracy - 0.6210092687950567


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.68it/s]


Val results: Loss - 39.70036005973816, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #6: Loss - 81.11225920915604, Accuracy - 0.6220391349124614


Test Batch: 100%|██████████| 61/61 [00:34<00:00,  1.79it/s]


Val results: Loss - 39.813989102840424, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:58<00:00,  1.03it/s]


Epoch #7: Loss - 81.16790056228638, Accuracy - 0.621524201853759


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.73it/s]


Val results: Loss - 39.81429550051689, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #8: Loss - 81.13158071041107, Accuracy - 0.6220391349124614


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.67it/s]


Val results: Loss - 39.65514951944351, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:57<00:00,  1.03it/s]


Epoch #9: Loss - 80.91155672073364, Accuracy - 0.6210092687950567


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.74it/s]


Val results: Loss - 39.53479582071304, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:57<00:00,  1.04it/s]


Epoch #10: Loss - 81.37294042110443, Accuracy - 0.6220391349124614


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.71it/s]


Val results: Loss - 39.629962891340256, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.02it/s]


Epoch #11: Loss - 80.4545875787735, Accuracy - 0.6204943357363543


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.67it/s]


Val results: Loss - 39.533889412879944, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.03it/s]


Epoch #12: Loss - 80.85516494512558, Accuracy - 0.6220391349124614


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.70it/s]


Val results: Loss - 39.53888702392578, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #13: Loss - 80.26466810703278, Accuracy - 0.6204943357363543


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.71it/s]


Val results: Loss - 39.544360995292664, Accuracy - 0.6440329218106996


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #14: Loss - 80.04754480719566, Accuracy - 0.621524201853759


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.69it/s]


Val results: Loss - 39.440973073244095, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #15: Loss - 80.43023785948753, Accuracy - 0.6204943357363543


Test Batch: 100%|██████████| 61/61 [00:37<00:00,  1.65it/s]


Val results: Loss - 39.58375430107117, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #16: Loss - 80.24688351154327, Accuracy - 0.6184346035015448


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.67it/s]


Val results: Loss - 39.47732424736023, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #17: Loss - 80.17758166790009, Accuracy - 0.6261585993820803


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.65it/s]


Val results: Loss - 39.35185843706131, Accuracy - 0.6388888888888888


Train Batch: 100%|██████████| 122/122 [02:02<00:00,  1.00s/it]


Epoch #18: Loss - 80.1347608268261, Accuracy - 0.6143151390319258


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.65it/s]


Val results: Loss - 39.39938676357269, Accuracy - 0.6450617283950617


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #19: Loss - 79.60122495889664, Accuracy - 0.6266735324407827


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.66it/s]


Val results: Loss - 39.30166247487068, Accuracy - 0.6430041152263375
Fold #1:


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #0: Loss - 82.40230923891068, Accuracy - 0.6320123520329387


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.67it/s]


Val results: Loss - 42.68864643573761, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #1: Loss - 81.98349544405937, Accuracy - 0.6366443643849717


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.71it/s]


Val results: Loss - 41.47284513711929, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #2: Loss - 81.02854138612747, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.65it/s]


Val results: Loss - 41.7760391831398, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #3: Loss - 80.75624322891235, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.71it/s]


Val results: Loss - 41.28472048044205, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #4: Loss - 80.78207218647003, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:34<00:00,  1.75it/s]


Val results: Loss - 41.54714250564575, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #5: Loss - 79.99825268983841, Accuracy - 0.6345856922285126


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.73it/s]


Val results: Loss - 41.60062322020531, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #6: Loss - 81.51609072089195, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.72it/s]


Val results: Loss - 41.46882206201553, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [01:58<00:00,  1.03it/s]


Epoch #7: Loss - 80.29541826248169, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.70it/s]


Val results: Loss - 40.93819862604141, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #8: Loss - 80.01593297719955, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:34<00:00,  1.74it/s]


Val results: Loss - 41.01846778392792, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #9: Loss - 80.2685034275055, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.68it/s]


Val results: Loss - 41.12377279996872, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [01:59<00:00,  1.02it/s]


Epoch #10: Loss - 79.59914514422417, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.73it/s]


Val results: Loss - 40.92094510793686, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #11: Loss - 79.50860899686813, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:35<00:00,  1.71it/s]


Val results: Loss - 40.701095163822174, Accuracy - 0.6127703398558187


Train Batch: 100%|██████████| 122/122 [02:00<00:00,  1.01it/s]


Epoch #12: Loss - 79.47807559370995, Accuracy - 0.638188368502316


Test Batch: 100%|██████████| 61/61 [00:36<00:00,  1.67it/s]


Val results: Loss - 40.88321179151535, Accuracy - 0.6127703398558187


Train Batch:  74%|███████▍  | 90/122 [01:29<00:31,  1.02it/s]

In [None]:
from src.RNN_utils.cross_val import plotCV

plotCV(results, configs,title='Cross Validation for CNN MobileNetV2 with Cross-Entropy Loss')

In [None]:
print(results)