In [1]:
import datetime
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

import load_csv_timbre

In [2]:
# Set to True to make CUDA kernel launches synchronous, which gives accurate
# tracebacks for device-side errors at a large speed cost. CUDA reads
# CUDA_LAUNCH_BLOCKING only from the process *environment*, and it must be set
# before CUDA is initialised; a plain Python variable of that name has no
# effect. Restart the kernel after changing this flag.
CUDA_LAUNCH_BLOCKING = False
if CUDA_LAUNCH_BLOCKING:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Prefer CUDA, then Apple-silicon MPS, then fall back to the CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(device)
# device = "cpu"  # uncomment to force CPU execution

BATCH_SIZE = 4096
# load_MSD is project-local; presumably it returns (train, test, validation)
# DataLoaders already targeting `device` -- TODO confirm against load_csv_timbre.
train_loader, test_loader, validation_loader = load_csv_timbre.load_MSD(batch_size=BATCH_SIZE, device=device)

print("Data Loaded")

cuda
Data Loaded


In [4]:
import time
# class model(nn.Module):
#     def __init__(self):
#         super(model, self).__init__()

#         self.fc1 = nn.Linear(128 * 13, 512) # changed size
#         self.fc2 = nn.Linear(512, 128)
#         self.fc3 = nn.Linear(128, 12)

#     def forward(self, x):
#         # reshape the input to be (batchsize, 128*3)
#         x = x.view(x.size(0), -1)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

class model(nn.Module):
    """Small 2-D CNN regressor mapping a spectrogram patch to `n_outputs` values.

    Architecture: two unpadded stride-1 convolutions -> dropout -> two linear
    layers. Submodule names (conv1, conv2, dropout, fc1, fc2) are unchanged, so
    existing checkpoints still load with ``load_state_dict``.

    Args:
        n_mels: size of input axis 1, the axis the 9-tall conv1 kernel slides
            over (mel bands, per the original comment).
        n_frames: size of input axis 2 (time samples).
        n_outputs: width of the final linear layer.
        dropout: dropout probability applied after the convolutions.

    The defaults reproduce the previously hard-coded fc1 input size of 8496
    (= 8 channels * (128 - 10) * (13 - 4)), i.e. a 128 x 13 input patch.

    Raises:
        ValueError: if the input is too small to survive both convolutions.
    """

    def __init__(self, n_mels=128, n_frames=13, n_outputs=12, dropout=0.2):
        # Zero-argument super() binds to this class through __class__. The
        # notebook later rebinds the global name `model` to an *instance*
        # (``model = model().to(device)``), after which ``super(model, self)``
        # would raise TypeError on any further construction of this class.
        super().__init__()

        # look at 4 mel bands above and below, and look at 1 time sample before and after
        self.conv1 = nn.Conv2d(1, 4, kernel_size=(9, 3), stride=(1, 1))
        self.conv2 = nn.Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
        self.dropout = nn.Dropout(dropout)

        # Each unpadded stride-1 convolution shrinks an axis by (kernel - 1).
        out_h = n_mels - (self.conv1.kernel_size[0] - 1) - (self.conv2.kernel_size[0] - 1)
        out_w = n_frames - (self.conv1.kernel_size[1] - 1) - (self.conv2.kernel_size[1] - 1)
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input of size ({n_mels}, {n_frames}) is too small for the "
                f"conv stack (needs at least (11, 5))"
            )

        self.fc1 = nn.Linear(self.conv2.out_channels * out_h * out_w, 1024)
        self.fc2 = nn.Linear(1024, n_outputs)

    def forward(self, x):
        """Map a (batch, n_mels, n_frames) tensor to (batch, n_outputs)."""
        x = x.unsqueeze(1)  # add the single input-channel axis Conv2d expects
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
# class model(nn.Module):
#     def __init__(self):
#         super(model, self).__init__()

#         # look at 4 mel bands above and below, and look at 1 time sample before and after
#         self.conv1 = nn.Conv2d(1, 8, kernel_size=(9, 3), stride=(1, 1))
#         self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
#         self.dropout = nn.Dropout(0.4)
#         self.fc1 = nn.Linear(16992, 4096)
#         self.fc2 = nn.Linear(4096, 1024)
#         self.fc3 = nn.Linear(1024, 12)

#     def forward(self, x):
#         x = x.unsqueeze(1)
#         x = F.relu(self.conv1(x))
#         x = F.relu(self.conv2(x))
#         x = self.dropout(x)
#         x = x.view(x.size(0), -1)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

        
# Resume from the "best" checkpoint written at the end of epoch `last_epoch`.
last_epoch = 55

# NOTE: this rebinds the class name `model` to the instance.
model = model().to(device)

# map_location lets a checkpoint written on one device (e.g. CUDA) be restored
# on whichever device was selected in the setup cell (e.g. MPS or CPU).
saved_model = torch.load(
    f'small_specgram_timbre_model_epoch_{last_epoch}_best.pth',
    map_location=device,
)
model.load_state_dict(saved_model['model_state_dict'])

print(model)

criterion = nn.MSELoss()
# lr here is only a placeholder: load_state_dict below replaces the param
# groups' hyper-parameters (including lr) with those stored in the checkpoint.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
optimizer.load_state_dict(saved_model['optimizer_state_dict'])
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 15, gamma = 0.1)
# scheduler.load_state_dict(saved_model['scheduler_state_dict'])

epochs = 80
# Best validation loss seen so far. Checkpoints that store it are used directly;
# older checkpoints without the key fall back to the value previously
# hard-coded here.
min_valid_loss = saved_model.get('min_valid_loss', 335.4599358694894)

model(
  (conv1): Conv2d(1, 4, kernel_size=(9, 3), stride=(1, 1))
  (conv2): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
  (dropout): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=8496, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=12, bias=True)
)


In [5]:
#start train loop
# Trains from epoch `last_epoch` up to `epochs`. Each epoch does one
# gradient-updating pass over train_loader and one gradient-free pass over
# validation_loader. It always writes a "_current" checkpoint, and writes a
# "_best" checkpoint only when the validation loss improves.


def _checkpoint_path(epoch, tag):
    """Return the checkpoint filename for 1-based `epoch` and `tag` ('best' / 'current')."""
    return f'small_specgram_timbre_model_epoch_{epoch}_{tag}.pth'


def _save_checkpoint(path, epoch, best_loss):
    """Save model and optimizer state, plus the bookkeeping needed to resume, to `path`."""
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # 'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch,
                'min_valid_loss': best_loss,
                }, path)


def _mean_batch_loss(loader, train):
    """Run one pass over `loader` and return the mean of the per-batch losses.

    With `train=True` it backpropagates and steps the optimizer for every batch.
    With `train=False` it runs under torch.no_grad(), so no autograd graph is
    built during evaluation.

    Raises:
        ValueError: if `loader` yields no batches (the mean would be undefined).
    """
    total_loss = 0.0
    n_batches = 0
    for data, labels in tqdm(loader):
        data, labels = data.to(device), labels.to(device)
        if train:
            optimizer.zero_grad()
            prediction = model(data.float())
            loss = criterion(prediction, labels.float())
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                prediction = model(data.float())
                loss = criterion(prediction, labels.float())
        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError('loader yielded no batches')
    return total_loss / n_batches


for e in range(last_epoch, epochs):
    epoch = e + 1  # 1-based epoch number used in logs and checkpoint names
    start = time.time()
    start_time = datetime.datetime.now()

    # actual model training
    model.train()
    train_loss = _mean_batch_loss(train_loader, train=True)

    # validation loss (eval mode disables dropout)
    model.eval()
    validation_loss = _mean_batch_loss(validation_loader, train=False)

    print('----------------------------------------------------------')
    if validation_loss < min_valid_loss:
        best_path = _checkpoint_path(epoch, 'best')
        print(f'validation loss decreased: {min_valid_loss} -> {validation_loss}')
        print(f'saving current model as {best_path}')
        min_valid_loss = validation_loss
        _save_checkpoint(best_path, epoch, min_valid_loss)

    end = time.time()
    end_time = datetime.datetime.now()
    # Report the optimizer's live learning rate(s) rather than a constant.
    learning_rates = [group['lr'] for group in optimizer.param_groups]
    print(f"Epoch: {epoch}/{epochs}")
    print(f'Start @ {start_time:%H:%M}, End @ {end_time:%H:%M}')
    print(f"Epoch Duration: {end-start:.2f}s / {(end-start) / 60:.2f} min")
    print(f"Training Loss: {train_loss:.6f}")
    print(f"Validation Loss: {validation_loss:.6f}")
    print(f'Learning rate: {learning_rates}')
    # print(f'Learning rate: {scheduler.get_last_lr()}')
    print('----------------------------------------------------------\n')
    _save_checkpoint(_checkpoint_path(epoch, 'current'), epoch, min_valid_loss)
    # scheduler.step()

205it [35:47, 10.48s/it]
42it [03:20,  4.76s/it]


----------------------------------------------------------
Epoch: 56/80
Start @ 0:11, End @ 0:50
Epoch Duration: 2348.54s / 39.14 min
Training Loss: 464.542144
Validation Loss: 336.212185
Learning rate: [0.005]
----------------------------------------------------------



205it [39:36, 11.59s/it]
42it [03:21,  4.79s/it]


----------------------------------------------------------
Epoch: 57/80
Start @ 0:50, End @ 1:33
Epoch Duration: 2578.20s / 42.97 min
Training Loss: 456.479237
Validation Loss: 335.792884
Learning rate: [0.005]
----------------------------------------------------------



205it [39:28, 11.55s/it]
42it [03:19,  4.75s/it]


----------------------------------------------------------
validation loss decreased: 335.4599358694894 -> 335.3856809706915
saving current model as big_specgram_timbre_model_epoch_58_best.pth
Epoch: 58/80
Start @ 1:33, End @ 2:16
Epoch Duration: 2568.27s / 42.80 min
Training Loss: 455.655438
Validation Loss: 335.385681
Learning rate: [0.005]
----------------------------------------------------------



205it [39:51, 11.67s/it]
42it [03:17,  4.71s/it]


----------------------------------------------------------
Epoch: 59/80
Start @ 2:16, End @ 2:59
Epoch Duration: 2589.90s / 43.16 min
Training Loss: 449.913362
Validation Loss: 336.231813
Learning rate: [0.005]
----------------------------------------------------------



205it [39:14, 11.48s/it]
42it [03:19,  4.74s/it]


----------------------------------------------------------
validation loss decreased: 335.3856809706915 -> 334.6243776593889
saving current model as big_specgram_timbre_model_epoch_60_best.pth
Epoch: 60/80
Start @ 2:59, End @ 3:41
Epoch Duration: 2553.57s / 42.56 min
Training Loss: 460.495707
Validation Loss: 334.624378
Learning rate: [0.005]
----------------------------------------------------------



205it [39:25, 11.54s/it]
42it [03:18,  4.73s/it]


----------------------------------------------------------
Epoch: 61/80
Start @ 3:41, End @ 4:24
Epoch Duration: 2563.99s / 42.73 min
Training Loss: 464.641538
Validation Loss: 335.218439
Learning rate: [0.005]
----------------------------------------------------------



205it [39:49, 11.66s/it]
42it [03:19,  4.75s/it]


----------------------------------------------------------
Epoch: 62/80
Start @ 4:24, End @ 5:7
Epoch Duration: 2589.45s / 43.16 min
Training Loss: 456.778566
Validation Loss: 335.259866
Learning rate: [0.005]
----------------------------------------------------------



205it [39:03, 11.43s/it]
42it [03:18,  4.72s/it]


----------------------------------------------------------
Epoch: 63/80
Start @ 5:7, End @ 5:50
Epoch Duration: 2541.96s / 42.37 min
Training Loss: 454.500990
Validation Loss: 334.948167
Learning rate: [0.005]
----------------------------------------------------------



205it [39:07, 11.45s/it]
42it [03:19,  4.75s/it]


----------------------------------------------------------
Epoch: 64/80
Start @ 5:50, End @ 6:32
Epoch Duration: 2546.63s / 42.44 min
Training Loss: 459.385780
Validation Loss: 334.862365
Learning rate: [0.005]
----------------------------------------------------------



205it [39:34, 11.58s/it]
42it [03:18,  4.73s/it]


----------------------------------------------------------
Epoch: 65/80
Start @ 6:32, End @ 7:15
Epoch Duration: 2573.50s / 42.89 min
Training Loss: 454.092628
Validation Loss: 334.734373
Learning rate: [0.005]
----------------------------------------------------------



205it [38:55, 11.39s/it]
42it [03:05,  4.42s/it]


----------------------------------------------------------
Epoch: 66/80
Start @ 7:15, End @ 7:57
Epoch Duration: 2521.74s / 42.03 min
Training Loss: 453.092661
Validation Loss: 334.628201
Learning rate: [0.005]
----------------------------------------------------------



124it [22:45, 11.01s/it]


KeyboardInterrupt: 