# In [3]:
# %%
# %%
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# from torchviz import make_dot
import os, fnmatch
import torchaudio
import sounddevice as sd
import soundfile as sf

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import array
import torch.fft as fft
from CustomDataloader import CustomDataloaderCreator,DataConfig
import tqdm
from collections import OrderedDict
import wandb
from plottingHelper import compareTwoAudios

class Trainer():
    """Train a model on spectral frames and keep the best weights seen.

    Each batch from the dataloaders is an ``(inputs, targets)`` pair. The
    inputs are reduced to their magnitude with ``torch.abs``, cast to the
    dtype selected by ``dataConfig.dtype``, flattened to two dimensions and
    passed through the model; the model output is compared with the targets
    using ``loss_func``.

    Attributes:
        best_vals: Best metrics so far. Only ``'loss'`` (lowest average
            validation loss, a Python float) is updated by this class.
        best_model: Detached copy of the model ``state_dict`` that produced
            ``best_vals['loss']`` (the initial weights until an epoch
            improves on the starting value).
    """

    def __init__(self,model,optimizer,loss_func,num_epochs):
        """Store the training components and snapshot the initial weights.

        Args:
            model: Module to train; called with a 2-D tensor per batch.
            optimizer: Optimizer already bound to the model's parameters.
            loss_func: Callable ``loss_func(outputs, targets)`` returning a
                scalar tensor.
            num_epochs: Number of passes over the training dataloader.
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.num_epochs = num_epochs

        self.best_vals = {'psnr': 0.0, 'loss': 1e8}
        self.best_model = OrderedDict((k, v.detach().clone()) for k, v in self.model.state_dict().items())

    def _forward_batch(self, batchInputs, dataConfig, fft_freq_bins, debugFlag):
        """Turn one raw input batch into model outputs."""
        # The original notes say the inputs arrive as complex64; only the
        # magnitude is fed to the model, cast to the configured precision.
        magnitudes = torch.abs(batchInputs)
        if dataConfig.dtype == torch.float32:
            magnitudes = magnitudes.float()
        else:
            magnitudes = magnitudes.double()

        # Flatten the buffered frames into one feature vector per example,
        # i.e. (batch, fft_freq_bins * modelBufferFrames).
        reshaped_input = magnitudes.view(
            magnitudes.shape[0], fft_freq_bins * dataConfig.modelBufferFrames
        )
        outputs = self.model(reshaped_input)

        if debugFlag:
            print(f'ifftedOutputs.shape = {outputs.shape}')
            print(f'IFFT of model output shape = {outputs.shape}')
            print(f'IFFT of model output first {dataConfig.stride_length} samples shape = {outputs.shape}')

        return outputs

    def _train_one_epoch(self, train_dataloader, dataConfig, fft_freq_bins, epoch, debugFlag, useWandB):
        """Run one optimisation pass over the training data; return the mean batch loss as a float."""
        # Make sure gradient tracking / training-mode layers are on.
        self.model.train(True)

        running_trainloss = 0.0
        # One batch per epoch is picked for the audio comparison helper.
        randomSelectedBatchNum = np.random.randint(0, len(train_dataloader))

        for batchNum, data in enumerate(train_dataloader):
            modelInputs, targets = data
            randomSelectedTrainingPoint = np.random.randint(0, targets.shape[0])

            # Reset gradients for every batch: step() runs per batch, so
            # clearing only once per epoch would accumulate stale gradients.
            self.optimizer.zero_grad()

            outputs = self._forward_batch(modelInputs, dataConfig, fft_freq_bins, debugFlag)
            loss = self.loss_func(outputs, targets)

            if batchNum == randomSelectedBatchNum and epoch % 10 == 0:
                compareTwoAudios(outputs[randomSelectedTrainingPoint],targets[randomSelectedTrainingPoint],epoch,randomSelectedBatchNum,logInWandb = useWandB)

            # .item() yields a plain float so the autograd graph of each
            # batch can be freed instead of being kept alive by the sum.
            running_trainloss += loss.item()
            loss.backward()
            self.optimizer.step()

        return running_trainloss / len(train_dataloader)

    def _validate(self, val_dataloader, dataConfig, fft_freq_bins, debugFlag):
        """Evaluate the model on the validation data; return the mean batch loss as a float."""
        running_vloss = 0.0
        self.model.eval()

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for val_modelInputs, val_targets in val_dataloader:
                val_outputs = self._forward_batch(val_modelInputs, dataConfig, fft_freq_bins, debugFlag)
                running_vloss += self.loss_func(val_outputs, val_targets).item()

        return running_vloss / len(val_dataloader)

    def train(self,train_dataloader,val_dataloader,dataConfig,modelSaveDir,wandbName,debugFlag = False,useWandB = True):
        """Train for ``num_epochs`` epochs, validating and checkpointing each epoch.

        After every epoch the average training and validation losses are
        printed (and logged to Weights & Biases when enabled). Whenever the
        validation loss improves on ``best_vals['loss']`` the weights are
        copied to ``best_model`` and saved to ``<modelSaveDir>/<wandbName>.pt``.

        Args:
            train_dataloader: Sized iterable of ``(inputs, targets)`` batches.
            val_dataloader: Sized iterable of ``(inputs, targets)`` batches.
            dataConfig: Configuration object; this method reads ``n_fft``,
                ``modelBufferFrames``, ``dtype`` and ``stride_length``, plus
                the hyper-parameters reported to W&B when ``useWandB`` is set.
            modelSaveDir: Directory for the checkpoint; created if missing.
                A falsy value means the current working directory.
            wandbName: Run name, also used as the checkpoint file stem.
            debugFlag: Print tensor shapes for every batch.
            useWandB: Log the run to Weights & Biases.
        """
        name = wandbName

        if useWandB:
            wandb.init(
                # set the wandb project where this run will be logged
                project="Shobhit_SEM9",
                name= name,
                config={
                    "epochs": self.num_epochs,
                    "learning_rate": dataConfig.learningRate,
                    "batch_size": dataConfig.batchSize,
                    "stride_length": dataConfig.stride_length,
                    "frame_size": dataConfig.frameSize,
                    "sample_rate": dataConfig.sample_rate,
                    "duration": dataConfig.duration,
                    "n_fft": dataConfig.n_fft,
                    "modelBufferFrames": dataConfig.modelBufferFrames,
                    "shuffle": dataConfig.shuffle,
                    "dtype": dataConfig.dtype,
                },
            )

        # Honour the requested output directory for the checkpoint.
        saveDir = modelSaveDir if modelSaveDir else '.'
        modelPath = os.path.join(saveDir, f'{name}.pt')

        # Number of bins in a one-sided spectrum of length n_fft.
        fft_freq_bins = int(dataConfig.n_fft/2) + 1

        try:
            with tqdm.trange(self.num_epochs, ncols=100) as t:
                for epoch in t:
                    avg_trainloss = self._train_one_epoch(
                        train_dataloader, dataConfig, fft_freq_bins, epoch, debugFlag, useWandB
                    )
                    avg_vloss = self._validate(val_dataloader, dataConfig, fft_freq_bins, debugFlag)
                    print('LOSS train {} valid {}'.format(avg_trainloss, avg_vloss))

                    if useWandB:
                        # Log results to W&B
                        wandb.log({
                            'trainLoss': avg_trainloss,
                            'valLoss': avg_vloss,
                        })

                    # Save the model if the validation loss improved
                    if avg_vloss < self.best_vals['loss']:
                        self.best_vals['loss'] = avg_vloss
                        self.best_model = OrderedDict((k, v.detach().clone()) for k, v in self.model.state_dict().items())
                        os.makedirs(saveDir, exist_ok=True)
                        torch.save(self.best_model, modelPath)
        finally:
            # Close the W&B run even when training is interrupted by an error.
            if useWandB:
                wandb.finish()
