In [1]:
from IPython.display import clear_output
from IPython.display import Javascript
import pysnooper
import os
from sklearn.metrics import precision_score, recall_score, f1_score
import random
import pickle
from torch.utils.tensorboard import SummaryWriter
import argparse
from torch.utils import data
from torch.optim.lr_scheduler import ExponentialLR
from torch.nn import functional as F
from torch.distributions import kl_divergence, Normal
from torch import optim
import sys
import numpy as np
import torch
import statistics as stat
import pdb
%run MusicData.ipynb
%run model-CCA.ipynb
%run utils.ipynb

In [2]:
batch_size = 4
data_length = 64  # 48
need_listen = True
is_load = True
model_name = 'cca-simple'
learning_rate = 0.001
alpha = 0.1 # CCA
beta = 0.01 # KLD
gamma = 0.5 # CCA background
generation_mode = True

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

### Data Pre-processing


In [3]:
data_path = "/gpfsnyu/scratch/yz6492/multimodal/data/MusicData_full.pkl"
with open(data_path, 'rb') as f:
    all_data = pickle.load(f)
    random.Random(42).shuffle(all_data)
train_data = all_data[:int(0.9*len(all_data))]
valid_data = all_data[int(0.9*len(all_data)):] #int(0.9*len(all_data))]
# test_data  = all_data[int(0.9*len(all_data)):]

train_data = data_augment(train_data, data_length)
valid_data = data_augment(valid_data, data_length)
# test_data = data_augment(test_data, data_length)

# check for data/batch coherence
train_data = train_data[:len(train_data)//batch_size * batch_size]
valid_data = valid_data[:len(valid_data)//batch_size * batch_size]
# test_data = test_data[:len(test_data)//batch_size * batch_size]

print(f'data length: train: {len(train_data)},valid: {len(valid_data)}') #,test: {len(test_data)}, ')

train_data = train_data[:int(len(train_data)/batch_size)] * batch_size

random.shuffle(train_data)
random.shuffle(valid_data)

all_data = []

data length: train: 56556,valid: 6160


In [4]:
collect = [0,0,0]
for data_ in all_data:
    if data_.culture == 'Chinese':
        collect[0] += 1
    elif data_.culture == 'English':
        collect[1] += 1
    else:
        collect[2] += 1
print(collect)

[0, 0, 0]


In [5]:
# train_data = Dataset(train_data)
# train_loader = data.DataLoader(train_data, batch_size=1, shuffle=True)
# for i, labels in train_loader:
#     factors = [labels[0][0],labels[1][0],labels[2][0]]
#     for j in factors:
#         if torch.sum(j) < 1:
#             print(str(j))

In [6]:
culture_list = ['Chinese','English','Irish'] # 3
# key_list = ['C ','D ','D-','E ','E-','F ','G ','G-','A ','A-','B ','B-'] # 24 + 1 = 25
key_list = ['major','minor', 'key_others']
# meter_list = [2,3,4,6,8,0] # 6
meter_list = ['3','4', 'meter_others']

In [7]:
key_tensor = torch.zeros(len(key_list))
meter_tensor = torch.zeros(len(meter_list))
culture_tensor = torch.zeros(len(culture_list))
word2idx = dict()
for i in range(len(key_tensor)+len(meter_tensor)+len(culture_tensor)):
    if i < len(key_tensor):
        new_tensor = key_tensor.clone()
        new_tensor[i] = 1
        word2idx[key_list[np.argmax(new_tensor)]] = i
    elif i < (len(key_tensor) + len(meter_tensor)):
        new_tensor = meter_tensor.clone()
        new_tensor[i-len(key_tensor)] = 1
        word2idx[meter_list[np.argmax(new_tensor)]] = i
    else:
        new_tensor = culture_tensor.clone()
        new_tensor[i - (len(key_tensor) + len(meter_tensor))] = 1
        word2idx[culture_list[np.argmax(new_tensor)]] = i

In [8]:
word2idx

{'major': 0,
 'minor': 1,
 'key_others': 2,
 '3': 3,
 '4': 4,
 'meter_others': 5,
 'Chinese': 6,
 'English': 7,
 'Irish': 8}

In [9]:
class Dataset(data.Dataset):
    def __init__(self, data):
        
#         self.length = np.concatenate(all_data.numpy()).shape[0]//(batch_size*data_length)*batch_size*data_length
#         self.data = np.concatenate(all_data.numpy())[:length].reshape(-1, batch_size, data_length, 130)
    
        self.data = data
        
        self.culture_list = culture_list
        self.key_list = key_list
        self.meter_list = meter_list
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        X = self.data[index].torch_matrix[:data_length]
        culture = self.data[index].culture
        key = self.data[index].key
        meter = self.data[index].meter
        
        # tokenize
        culture_tensor = torch.zeros(len(self.culture_list))
        culture_index = self.culture_list.index(culture)
        culture_tensor[culture_index] = 1
        
        meter_tensor = torch.zeros(len(self.meter_list))
#         if int(meter[0]) in self.meter_list:
#             meter_index = self.meter_list.index(int(meter[0]))
#             meter_tensor[meter_index] = 1
        if int(meter[0]) % 3 == 0:
            meter_tensor[0] = 1
        elif int(meter[0]) % 2 == 0:
            meter_tensor[1] = 1
        else:
            meter_tensor[-1] = 1
        
        key_tensor = torch.zeros(len(self.key_list))
        # key_index = self.key_list.index(key[:2].upper())
        # print(key)
        if key.split()[-1] == 'major':
            key_tensor[0] = 1
        elif key.split()[-1] == 'minor':
            key_tensor[1] = 1
        else: # dorian and else
            key_tensor[2] = 1
            
#         print(key_tensor, culture_tensor, meter_tensor)
        
        
        return X, (culture_tensor, key_tensor, meter_tensor)

### Loss

In [10]:
def loss_function(recon, target_tensor, dis, 
                  language_index, z_matrix, 
                 alpha=1,
                 beta=1):
    '''
    提取出来的z_matrix要和所有的embedding做区分。所有特征加起来共有34种可能性，所以计算一个34*3的correlation。
    '''
    
    
    CE = F.nll_loss(
        recon.view(-1, recon.size(-1)),
        target_tensor,
        reduction='mean')
    normal = Normal(
        torch.zeros(dis.mean.size()).cuda(),
        torch.ones(dis.stddev.size()).cuda())
    KLD = kl_divergence(dis, normal).mean()
    
    (batch_size, n_features, n_dims) = z_matrix.size()
    # 
    embed_weight = model.word_embeds.weight
    corr_mask = np.zeros([batch_size, embed_weight.size(0), n_features]) # (bsz, 34, 3)
    for i, sample in enumerate(language_index.cpu().numpy()): # sample:(3), language_index: (bsz, 3 )
        for j, element in enumerate(sample): # element: index(1)
            corr_mask[i,element, j] = 1
    noncorr_mask = np.ones([batch_size, embed_weight.size(0), n_features]) - corr_mask
    embed_weight_stacked = embed_weight.repeat(batch_size, 1, 1)
    
    corr_mask, noncorr_mask = torch.from_numpy(corr_mask).cuda(), torch.from_numpy(noncorr_mask).cuda()
    
    z_matrix_norm = torch.norm(z_matrix, dim=-1, keepdim=True)
    embed_weight_stacked_norm = torch.norm(embed_weight_stacked, dim=-1, keepdim=True)
    
    corr = torch.bmm(embed_weight_stacked, z_matrix.permute(0,2,1)) # corr: (bsz, 34, 3)
    norm = torch.bmm(embed_weight_stacked_norm, z_matrix_norm.permute(0,2,1))
    
    res = corr / norm
    
    positive_result = res * corr_mask
    negative_result = torch.abs(res * noncorr_mask)
    
#     pdb.set_trace()
    
    CCA_loss = 3 + (-torch.sum(positive_result, dim = [1,2]) + gamma * torch.sum(negative_result, dim = [1,2])).mean()
    # CCA_loss = torch.sigmoid(CCA_loss_)

    # pdb.set_trace()
    return CE +  alpha * CCA_loss + beta * KLD, [CE, CCA_loss, KLD]

### train & validate

In [11]:
## train stats
loss_sum = list()
recon_sum = list()
cca_sum = list()
kld_sum = list()

best_loss = np.inf
history_loss_list = list()

In [12]:
def train(data_X, labels, step):
    global loss_sum, recon_sum, kld_sum, cca_sum, history_loss_list, best_loss
    
    input_tensor, label_tensor = data_X.clone(), data_X.clone()
    [culture_label, key_label, meter_label] = labels
    if torch.cuda.is_available():
        input_tensor, label_tensor = input_tensor.cuda(), label_tensor.cuda()
    optimizer.zero_grad()
    keywords = [[key_list[np.argmax(element)] for element in key_label], 
                [meter_list[np.argmax(element)] for element in meter_label], 
                [culture_list[np.argmax(element)] for element in culture_label]]
#     print(keywords)
    (recon, dis_mean, dis_stddev, language_index, z_matrix) = model(input_tensor, keywords)
    dis = Normal(dis_mean, dis_stddev)
#     pdb.set_trace()

    label_tensor = label_tensor.view(-1, label_tensor.size(-1)).max(-1)[1]
    key_label = key_label.view(-1, key_label.size(-1)).max(-1)[1]
    meter_label = meter_label.view(-1, meter_label.size(-1)).max(-1)[1]
    culture_label = culture_label.view(-1, culture_label.size(-1)).max(-1)[1]

    if torch.cuda.is_available():
        label_tensor, key_label, meter_label, culture_label = label_tensor.cuda(), key_label.cuda(), meter_label.cuda(), culture_label.cuda()

    loss, loss_elements = loss_function(recon, label_tensor, dis, 
              language_index, z_matrix,
             alpha=alpha, # ignore classifier
             beta=beta)
    loss.backward()
    
#     print(model.word_embeds.weight.grad)
#     print(model.music_embeds.weight.grad)

    loss_sum.append(loss.item()) 
    recon_sum.append(loss_elements[0].item())
    cca_sum.append(loss_elements[1].item())
    kld_sum.append(loss_elements[2].item())

    writer.add_scalars('Train Loss', {'Total Loss':loss.item()}, step)
#         for index, tag in enumerate(['Recon Loss','Key Loss','Meter Loss','Culture Loss', 'KLD']):
#             writer.add_scalars('Loss',{tag: loss_elements[index].item()}, step)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    optimizer.step()
    step += 1
    

    if step % 500 == 0:
        model.eval()
        valid_loader = data.DataLoader(valid_data, batch_size=len(valid_data), shuffle=True)
        with torch.no_grad():
            for valid_data_X, valid_label in valid_loader:
                valid_loss = validate(valid_data_X, valid_label)
#                 generate(valid_data_X, valid_label)
                break
#                 pdb.set_trace()
        model.train()
        print(f"batch {step}: Total {stat.mean(loss_sum):.5f}, Recon {stat.mean(recon_sum):.5f}, CCA {stat.mean(cca_sum):.5f}, KLD {stat.mean(kld_sum):.5f}," 
              f"Valid Loss {valid_loss:.5f}")
        writer.add_scalars('Valid Loss', {'Total Loss':valid_loss}, step)
        sys.stdout.flush()
        
        
            # early stop
        if valid_loss < best_loss:
            history_loss_list = list()
        else:
            history_loss_list.append(valid_loss)
        
        
        loss_sum = list()
        recon_sum = list()
        cca_sum = list()
        kld_sum = list()
        
        
        

    if step % 2500 == 0:
        scheduler.step()
    if step % 5000 == 0:
        model.eval()
        valid_loader = data.DataLoader(valid_data, batch_size=len(valid_data), shuffle=True)
        with torch.no_grad():
            for valid_data_X, valid_label in valid_loader:
                validate_f(valid_data_X, valid_label)
                break
        model.train()
#     if step % 5000 == 0:
#         for i in range(3):
#             idx = random.randint(0,data_X.size(0)-1)
#             print('label:')
#             numpy_to_midi(data_X[idx], output_folder = 'demo/train/', listen = need_listen)
#             print('generated:')
#             numpy_to_midi(recon[idx], output_folder = 'demo/train/', listen = need_listen)
    return step

In [13]:
def validate(data, labels, output = False):
#     for batch_idx, batch in enumerate(data):
    input_tensor, label_tensor = data, data
    [culture_label, key_label, meter_label] = labels
    if torch.cuda.is_available():
        input_tensor, label_tensor = input_tensor.cuda(), label_tensor.cuda()
    keywords = [[key_list[np.argmax(element)] for element in key_label], 
                [meter_list[np.argmax(element)] for element in meter_label], 
                [culture_list[np.argmax(element)] for element in culture_label]]
    (recon, dis_mean, dis_stddev, language_index, z_matrix) = model(input_tensor, keywords)
    dis = Normal(dis_mean, dis_stddev)
#         print(key_label)
    label_tensor = label_tensor.view(-1, label_tensor.size(-1)).max(-1)[1]
    if torch.cuda.is_available():
        label_tensor = label_tensor.cuda()
    loss, loss_elements = loss_function(recon, label_tensor, dis, 
              language_index, z_matrix,
             alpha=alpha, # ignore classifier
             beta=beta)
    
    # print(loss_elements[1])
#     for i in range(3):
#         idx = random.randint(0,data.size(0)-1)
#         print('label:')
#         numpy_to_midi(data[idx], output_folder = 'demo/valid/',listen = need_listen)
#         print('generated:')
#         numpy_to_midi(recon[idx], output_folder = 'demo/valid/', listen = need_listen)
    
    if output:
        print(f"Validation Loss: {loss_elements[0].item():.5f}, KLD {loss_elements[4].item():.5f}")
        sys.stdout.flush()
    return loss.item()

In [14]:
def validate_f(data, labels):
    # 计算每个label对应的embedding，直接做correlation。
    
#     for batch_idx, batch in enumerate(data):
    input_tensor, label_tensor = data, data
    [culture_label, key_label, meter_label] = labels
    if torch.cuda.is_available():
        input_tensor, label_tensor = input_tensor.cuda(), label_tensor.cuda()
    keywords = [[key_list[np.argmax(element)] for element in key_label], 
                [meter_list[np.argmax(element)] for element in meter_label], 
                [culture_list[np.argmax(element)] for element in culture_label]]
    (recon, dis_mean, dis_stddev, language_index, z_matrix) = model(input_tensor, keywords)
    
    key_embeds_matrix = model.word_embeds(torch.LongTensor([model.word2idx[str(j)]  for j in key_list]).cuda())
    meter_embeds_matrix = model.word_embeds(torch.LongTensor([model.word2idx[str(j)]  for j in meter_list]).cuda())
    culture_embeds_matrix = model.word_embeds(torch.LongTensor([model.word2idx[str(j)]  for j in culture_list]).cuda())
    
    key_label_b, culture_label_b, meter_label_b = \
        np.argmax(key_label.cpu().numpy(), axis = 1), \
        np.argmax(culture_label.cpu().numpy(), axis = 1),\
        np.argmax(meter_label.cpu().numpy(), axis = 1),

    z_key = z_matrix[:, 0, :].squeeze(1) # (batch size, emb size)
    key_result = np.zeros(input_tensor.size(0))
    for i, z_sample in enumerate(z_key): # z_sample: ([emb_size])
        z_key_norm = torch.norm(z_sample, dim=-1, keepdim=True) #[1]
        key_embeds_matrix_norm = torch.norm(key_embeds_matrix, dim=-1, keepdim=True) # [25]
        corr = key_embeds_matrix * z_sample # [25, 1]
        norm = key_embeds_matrix_norm * z_key_norm
        res = torch.sum(corr / norm, -1) # [25,]
        key_result[i] = np.argmax(res.cpu().numpy())
        
    z_meter = z_matrix[:, 1, :].squeeze(1) # (batch size, emb size)
    meter_result = np.zeros(input_tensor.size(0))
    for i, z_sample in enumerate(z_meter): # z_sample: ([emb_size])
        z_meter_norm = torch.norm(z_sample, dim=-1, keepdim=True) #[1]
        meter_embeds_matrix_norm = torch.norm(meter_embeds_matrix, dim=-1, keepdim=True) # [25]
        corr = meter_embeds_matrix * z_sample # [25, 1]
        norm = meter_embeds_matrix_norm * z_key_norm
        res = torch.sum(corr / norm, -1) # [25,]
        meter_result[i] = np.argmax(res.cpu().numpy())
        
    z_culture = z_matrix[:, 2, :].squeeze(1) # (batch size, emb size)
    culture_result = np.zeros(input_tensor.size(0))
    for i, z_sample in enumerate(z_culture): # z_sample: ([emb_size])
        z_culture_norm = torch.norm(z_sample, dim=-1, keepdim=True) #[1]
        culture_embeds_matrix_norm = torch.norm(culture_embeds_matrix, dim=-1, keepdim=True) # [25]
        corr = culture_embeds_matrix * z_sample # [25, 1]
        norm = culture_embeds_matrix_norm * z_key_norm
        res = torch.sum(corr / norm, -1) # [25,]
        culture_result[i] = np.argmax(res.cpu().numpy())


#     key_f = f1_score(key_label_b, key_result, average="macro")
#     culture_f = f1_score(culture_label_b, culture_result, average="macro")
#     meter_f = f1_score(meter_label_b, meter_result, average="macro")
    key_f = f1_score(key_label_b, key_result, average="micro")
    culture_f = f1_score(culture_label_b, culture_result, average="micro")
    meter_f = f1_score(meter_label_b, meter_result, average="micro")

    print(f"key f: {key_f}, meter f: {meter_f}, culture f: {culture_f}")

In [15]:
class MinExponentialLR(ExponentialLR):
    def __init__(self, optimizer, gamma, minimum, last_epoch=-1):
        self.min = minimum
        super(MinExponentialLR, self).__init__(optimizer, gamma, last_epoch=-1)

    def get_lr(self):
        return [
            max(base_lr * self.gamma**self.last_epoch, self.min)
            for base_lr in self.base_lrs
        ]

In [16]:
def testing(samples):
    [culture_label, key_label, meter_label] = labels
    (recon, dis_mean, dis_stddev, y_key, y_meter, y_culture) = model(input_tensor)

In [17]:
# def inference()

### Generation

In [18]:
def generate(data_X, labels, culture = 2, meter = -1, key = -1, resample = True):
    model.eval()
    input_tensor, label_tensor = data_X, data_X
    [culture_label, key_label, meter_label] = labels
    
    keywords = [[key_list[np.argmax(element)] for element in key_label], 
            [meter_list[np.argmax(element)] for element in meter_label], 
            [culture_list[np.argmax(element)] for element in culture_label]]
    
    key_label, culture_label, meter_label = \
        np.argmax(key_label.numpy(), axis = 1), \
        np.argmax(culture_label.numpy(), axis = 1),\
        np.argmax(meter_label.numpy(), axis = 1), # 将label从one-hot改为index
    if torch.cuda.is_available():
        input_tensor, label_tensor = input_tensor.cuda(), label_tensor.cuda()
    dis, mu, var = model.encoder(input_tensor)
    

    
    if resample:
        z = dis.rsample()
    
    sample1_index = 0
    # print(sample1_index)
    z_others = z[sample1_index][model.z_key_dims + model.z_meter_dims + model.z_culture_dims:]
    z_key = z[sample1_index][:model.z_key_dims]
    z_culture = z[sample1_index][model.z_key_dims+model.z_meter_dims:model.z_key_dims+model.z_meter_dims+model.z_culture_dims]
    z_meter = z[sample1_index][model.z_key_dims:model.z_key_dims+model.z_meter_dims]
    
    
    
    z_key_avg_dict_collection = {i:list() for i in range(len(key_list))}
    z_meter_avg_dict_collection = {i:list() for i in range(len(meter_list))}
    z_culture_avg_dict_collection = {i:list() for i in range(len(culture_list))}
    
    for i, sample in enumerate(input_tensor): # 收集各个部分的z
        z_key_avg_dict_collection[key_label[i]].append((mu[i][:model.z_key_dims], 
                                                   var[i][:model.z_key_dims]))
        z_meter_avg_dict_collection[meter_label[i]].append((mu[i][model.z_key_dims:model.z_key_dims+model.z_meter_dims], 
                                                       var[i][model.z_key_dims:model.z_key_dims+model.z_meter_dims]))
        z_culture_avg_dict_collection[culture_label[i]].append((mu[i][model.z_key_dims+model.z_meter_dims:model.z_key_dims+model.z_meter_dims+model.z_culture_dims], 
                                                   var[i][model.z_key_dims+model.z_meter_dims:model.z_key_dims+model.z_meter_dims+model.z_culture_dims]))
        
    #对每一类z都计算一个平均的Normal   
    def calculate_normal(item_list, dim = model.feature_dims):
        if len(item_list)>0:
            return Normal(torch.mean(torch.stack([m[0] for m in item_list], dim = -1), dim=-1),torch.mean(torch.stack([m[1] for m in item_list], dim = -1), dim=-1))
        return Normal(torch.zeros(dim), torch.ones(dim))

    z_key_avg_dict_normal = {index:calculate_normal(item_list)
                             for index, item_list in z_key_avg_dict_collection.items()}
    z_meter_avg_dict_normal = {index:calculate_normal(item_list)
                               for index, item_list in z_meter_avg_dict_collection.items()}
    z_culture_avg_dict_normal = {index:calculate_normal(item_list)
                                 for index, item_list in z_culture_avg_dict_collection.items()}
    
    
    for option in  ['key','meter','culture']:
        if option == 'key':
            sample_z = torch.stack([torch.cat([z_key.rsample().cuda(), z_meter, z_culture, z_others], dim = -1) for i, z_key in z_key_avg_dict_normal.items()])
        if option == 'culture':
            sample_z = torch.stack([torch.cat([z_key, z_meter, z_culture.rsample(), z_others], dim = -1) for i, z_culture in z_culture_avg_dict_normal.items()])
        if option == 'meter':
            sample_z = torch.stack([torch.cat([z_key, z_meter.rsample(), z_culture, z_others], dim = -1) for i, z_meter in z_meter_avg_dict_normal.items()])

    #     sample_z = torch.cat([z_key, z_meter, z_culture, z_others], dim = -1)
    #     sample_z = sample_z.unsqueeze(0)

        recon = model.decoder(sample_z)
    #     recon = recon[0]

        print("origin:")
        print(keywords[0][sample1_index], keywords[1][sample1_index], keywords[2][sample1_index])
        numpy_to_midi(data_X[sample1_index], listen = need_listen)
        print(f"{option}_transfer:")
        for i in range(recon.size(0)):
            numpy_to_midi(recon[i], listen = need_listen)
    

#         print(key_label)


In [19]:
writer = SummaryWriter()

model = DisentangledVAE(roll_dims = 130,
             hidden_dims = 512,
             embed_dims = 32,
             feature_dims = 32,
             z_other_dims = 160, 
             word2idx = word2idx,
             n_step = data_length)
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
scheduler = MinExponentialLR(optimizer, gamma=0.99, minimum=1e-4)
# step = 0
if torch.cuda.is_available():
    print('Using: ',
          torch.cuda.get_device_name(torch.cuda.current_device()))

if is_load:
    model.load_state_dict(torch.load("/gpfsnyu/scratch/yz6492/multimodal/model/cca-simple_epoch15_length64.pt"))

if torch.cuda.is_available():
    model.cuda()

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Using:  GeForce RTX 2080 Ti


### Training Code

In [None]:
# model training
step = 0
model.train()

params = list(model.parameters())
# print(params)
total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in params if x.size())
print('Model total parameters:', total_params)
sys.stdout.flush()



train_data = Dataset(train_data)
valid_data = Dataset(valid_data)
# test_data = Dataset(test_data)

for epoch in range(1, 100):
    
    if generation_mode:
        model.eval()
        valid_loader = data.DataLoader(valid_data, batch_size=len(valid_data), shuffle=True)
        with torch.no_grad():
            for valid_data_X, valid_label in valid_loader:
                valid_loss = validate(valid_data_X, valid_label)
                generate(valid_data_X, valid_label)
                pdb.set_trace()

            model.train()
    else:      
        print('Epoch: {}'.format(epoch))

        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

        if (epoch > 0) and (epoch % 2) == 0:
            clear_output(wait=True)

    #     # test_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
        for i, labels in train_loader:
            i = i.float()
            labels = [j.float() for j in labels]
            step = train(i, labels, step)

            if len(history_loss_list) > 5:
                torch.save(model.cpu().state_dict(), f'/gpfsnyu/scratch/yz6492/multimodal/model/{model_name}_BEST_length{data_length}.pt')
                model.cuda()


        if epoch % 5 == 0:
            torch.save(model.cpu().state_dict(), f'/gpfsnyu/scratch/yz6492/multimodal/model/{model_name}_epoch{epoch}_length{data_length}.pt')
            model.cuda()

# torch.save(model.cpu().state_dict(), f'/gpfsnyu/scratch/yz6492/multimodal/model/{model_name}_final.pt')

# test

writer.close()

Model total parameters: 5661122
origin:
major 4 Irish


key_transfer:


origin:
major 4 Irish


meter_transfer:


origin:
major 4 Irish


culture_transfer:


> <ipython-input-20-a14d00f57dab>(23)<module>()
-> for valid_data_X, valid_label in valid_loader:
(Pdb) c
origin:
major 4 Irish


key_transfer:


origin:
major 4 Irish


meter_transfer:


origin:
major 4 Irish


culture_transfer:


> <ipython-input-20-a14d00f57dab>(23)<module>()
-> for valid_data_X, valid_label in valid_loader:


### Generation

In [None]:


# train_data = Dataset(train_data)
# train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=False)
# for i, labels in train_loader:
#     i = i.float()
#     labels = [j.float() for j in labels]
#     generate(i, labels)
#     break