# Initialize

In [14]:
#@title Import {display-mode: "form"}
import math
import torch
from torch import nn

# Model

In [15]:
#@title Swish {display-mode: "form"}
class Swish(nn.Module):
  def forward(self, x):
    return x * torch.sigmoid(x)

In [16]:
#@title FeatureEx1d {display-mode: "form"}
class FeatureEx1d(nn.Module):
  def __init__(self, input_size):
    super().__init__()
    self.layers = nn.Sequential(
        nn.linear(input_size, 256),
        nn.functional.relu(),
        nn.linear(256, 128),
        nn.functional.relu(),
        nn.linear(128, 64),
        nn.functional.relu()
    )

  def forward(self, x):
    return self.layers(x)

In [17]:
#@title SEblock {display-mode: "form"}
class SEblock(nn.Module):
  def __init__(self, ch_in, ch_sq):
    super().__init__()
    self.se = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(ch_in, ch_sq, 1),
        Swish(),
        nn.Conv2d(ch_sq, ch_in, 1),
    )
    self.se.apply(weight_init)

  def forward(self, x):
    return x * torch.sigmoid(self.se(x))

def weight_init(m):
  if isinstance(m, nn.Conv2d):
    nn.init.kaiming_normal_(m.weight)

  if isinstance(m, nn.Linear):
    nn.init.kaiming_uniform_(m.weight)
    nn.init.zeros_(m.bias)

In [18]:
#@title ConvBNFirst {display-mode: "form"}
class ConvBNFirst(nn.Module):
  def __init__(self, ch_in, ch_out, kernel_size=(3, 257), stride=(1, 257), padding=0, groups=1):
    super().__init__()
    self.layers=nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(ch_out),
    )
    self.layers.apply(weight_init)

  def forward(self, x):
    return self.layers(x)

In [19]:
#@title ConvBN {display-mode: "form"}
class ConvBN(nn.Module):
  def __init__(self, ch_in, ch_out, kernel_size, stride=1, padding=0, groups=1):
    super().__init__()
    self.layers=nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(ch_out),
    )
    self.layers.apply(weight_init)

  def forward(self, x):
    return self.layers(x)

In [20]:
#@title DropConnect {display-mode: "form"}
class DropConnect(nn.Module):
  def __init__(self, drop_rate):
    super().__init__()
    self.drop_rate=drop_rate

  def forward(self, x):
    if self.training:
      keep_rate=1.0-self.drop_rate
      r = torch.rand([x.size(0),1,1,1], dtype=x.dtype).to(x.device)
      r+=keep_rate
      mask=r.floor()
      return x.div(keep_rate)*mask
    else:
      return x

In [21]:
#@title BMConvBlock {display-mode: "form"}
class BMConvBlock(nn.Module):
  def __init__(self,ch_in,ch_out,expand_ratio,stride,kernel_size,reduction_ratio=4,drop_connect_rate=0.2):
    super().__init__()
    self.use_residual = (ch_in==ch_out) & (stride==1)
    ch_med = int(ch_in*expand_ratio)
    ch_sq  = max(1, ch_in//reduction_ratio)

    if expand_ratio != 1.0:
      layers = [ConvBN(ch_in, ch_med, 1), Swish()]
    else:
      layers = []
    
    layers.extend([ConvBN(ch_med, ch_med, kernel_size, stride=stride, padding=0, groups=ch_med),
                  Swish(),
                  SEblock(ch_med, ch_sq),
                  ConvBN(ch_med, ch_out, 1),
                  ])
    
    if self.use_residual:
      self.drop_connect = DropConnect(drop_connect_rate)

    self.layers = nn.Sequential(*layers)

  def forward(self, x):
    if self.use_residual:
      return x + self.drop_connect(self.layers(x))
    else:
      return self.layers(x)

In [22]:
#@title Flatten {display-mode: "form"}
class Flatten(nn.Module):
  def forward(self, x):
    return x.view(x.shape[0], -1)

In [23]:
#@title LSTM {display-mode: "form"}
class LSTMClassifier(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, tagset_size=128):
        super(LSTMClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, x):
        x = x.permute(2,0,1,3)
        x = x.squeeze(dim=3)
        _, lstm_out = self.lstm(x)
        x = self.hidden2tag(lstm_out[0].view(-1, self.hidden_dim))
        return x

In [24]:
#@title MultiModalNet {display-mode: "form"}
class MultiModalNet(nn.Module):
  def __init__(self,
               width_mult=2.0,
               depth_mult=1.0,
               resolution=None,
               dropout_rate=0.2,
               num_1d_features=10,
               num_classes=4,
               input_ch=3):
    super().__init__()
    features = [BMConvBlock(input_ch, 512, expand_ratio=6, stride=(1,257), kernel_size=(3, 257))]


    self.features2d = nn.Sequential(*features)
    self.reshape = nn.Sequential(
        LSTMClassifier(embedding_dim=512, hidden_dim=512)
    )
    self.classifier = nn.Sequential(
        nn.Linear(128*2, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, num_classes),
        nn.Softmax(dim=1)
    )

    self.features1d = nn.Sequential(
        nn.Linear(num_1d_features, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
    )

  def forward(self, x):
    x1=self.features2d(x[1])
    x1=self.reshape(x1)
    x0=self.features1d(x[0])
    x = torch.cat((x0,x1), dim=1)
    x=self.classifier(x)

    return x

# Dataset

In [25]:
RESEARCH_WORK_PATH = "/content/drive/MyDrive/Colab Notebooks/BachelorResearch/"

In [26]:
import shutil
import os


data_dir = "/content/datas/"
if os.path.exists(data_dir):
  shutil.rmtree(data_dir)
for dirs in os.listdir(RESEARCH_WORK_PATH + "MER_audio_taffc_dataset_wav/5s_0.5shift/pickles1/"):
  if "Q" in dirs:
    shutil.copytree(RESEARCH_WORK_PATH + "MER_audio_taffc_dataset_wav/5s_0.5shift/pickles1/" + dirs, data_dir + dirs)

for q in "Q1 Q2 Q3 Q4".split(" "):
  d = RESEARCH_WORK_PATH + "MER_audio_taffc_dataset_wav/5s_0.5shift/pickles1/"  + q + "/"
  print(q, sum(os.path.isfile(os.path.join(d, name)) for name in os.listdir(d)))

for q in "Q1 Q2 Q3 Q4".split(" "):
  d = data_dir + q + "/"
  print(q, sum(os.path.isfile(os.path.join(d, name)) for name in os.listdir(d)))

Q1 9200
Q2 9200
Q3 9200
Q4 9200
Q1 9200
Q2 9200
Q3 9200
Q4 9200


In [27]:
import os
def make_filepath_list(root, train_rate=0.8):
  train_file_list = []
  valid_file_list = []

  for dirs in os.listdir(root):
    if "Q" in dirs:
      file_dir = os.path.join(root, dirs)
      file_list = os.listdir(file_dir)


      for f in file_list:
        split_num = int(f.split(".")[1].split("_")[-1])
        if split_num % 5 == 4:
          valid_file_list.append(os.path.join(root, dirs, f).replace('\\', '/'))
        else:
          train_file_list.append(os.path.join(root, dirs, f).replace('\\', '/'))

  
  return train_file_list, valid_file_list


In [28]:
from torch.utils import data
import numpy as np
import librosa
import pickle
import torch
import cv2

class musicDataset(data.Dataset):

  def __init__(self, file_list, classes, phase='train'):
    self.file_list = file_list
    self.classes = classes
    self.phase = phase

  def __len__(self):
    return len(self.file_list)


  def __getitem__(self, index):
    pickle_path = self.file_list[index]
    x,y = pickle.load(open(pickle_path, mode="rb"))
    return x, y


train_file_list, valid_file_list = make_filepath_list(data_dir)

print('学習データ数 : ', len(train_file_list))
print(train_file_list[0])

print('検証データ数 : ', len(valid_file_list))
print(valid_file_list[0])

q_classes = "Q1 Q2 Q3 Q4".split(" ")

train_dataset = musicDataset(
    file_list=train_file_list, classes=q_classes, phase='train'
)

valid_dataset = musicDataset(
    file_list=valid_file_list, classes=q_classes, phase='valid'
)


index = 0
print("Dataset1 shape:", train_dataset.__getitem__(index)[0][0].size())
print("Dataset2 shape:", train_dataset.__getitem__(index)[0][1].size())
print("Dataset label:", train_dataset.__getitem__(index)[1])

学習データ数 :  29600
/content/datas/Q1/Q1.MT0010489498_3.wav.pickle
検証データ数 :  7200
/content/datas/Q1/Q1.MT0032235381_24.wav.pickle
Dataset1 shape: torch.Size([26])
Dataset2 shape: torch.Size([1, 431, 257])
Dataset label: 0


# DataLoader

In [29]:
# バッチサイズの指定
batch_size = 64

# DataLoaderを作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True)

valid_dataloader = data.DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=os.cpu_count(), pin_memory=True)

# pin_mem and num_workers are acceleration technique -> https://qiita.com/sugulu_Ogawa_ISID/items/62f5f7adee083d96a587#11-num_workers

# 辞書にまとめる
dataloaders_dict = {
    'train': train_dataloader, 
    'valid': valid_dataloader
}

# 動作確認
# イテレータに変換
batch_iterator = iter(dataloaders_dict['train'])

# 1番目の要素を取り出す
inputs, labels = next(batch_iterator)

print(inputs[0].size())
print(inputs[1].size())
print(labels)


torch.Size([64, 26])
torch.Size([64, 1, 431, 257])
tensor([0, 3, 1, 0, 0, 0, 0, 3, 1, 1, 0, 3, 2, 2, 3, 2, 3, 0, 3, 2, 0, 2, 0, 0,
        1, 0, 1, 1, 3, 1, 1, 3, 3, 2, 3, 2, 2, 1, 0, 1, 1, 2, 0, 2, 0, 3, 1, 1,
        1, 1, 1, 3, 3, 2, 0, 2, 0, 2, 3, 0, 1, 0, 1, 0])


# Optimizer, Criterion

In [30]:
torch_clearlizer = False #@param {type: "boolean"}

if torch_clearlizer:
  del train_dataset
  del valid_dataset
  del train_dataloader
  del valid_dataloader
  del labels
  del loss_hist
  del acc_hist
  del optimizer
  del criterion
  torch.cuda.empty_cache()

In [31]:
from torch import optim

model     = MultiModalNet(input_ch=1, num_classes=4, num_1d_features=train_dataset.__getitem__(index)[0][0].size(0)).to('cuda')
optimizer = optim.SGD(model.parameters(),lr=0.1) 
criterion = nn.CrossEntropyLoss()

SGD -> じわっとloss減ってく


# Training

In [None]:
from tqdm import tqdm
# エポック数
num_epochs = 200


loss_hist = [[],[]]
acc_hist = [[],[]]

for epoch in range(num_epochs):
  print('Epoch {}/{}'.format(epoch+1, num_epochs))
  print('-------------')
  
  for phase in ['train', 'valid']:
      
    if phase == 'train':
      # モデルを訓練モードに設定
      model.train()
    else:
      # モデルを推論モードに設定
      model.eval()
    
    # 損失和
    epoch_loss = 0.0
    # 正解数
    epoch_corrects = 0
    
    # DataLoaderからデータをバッチごとに取り出す
    with tqdm(dataloaders_dict[phase],unit='batch',colour='green' if phase == 'train' else 'red') as pbar:
      pbar.set_description('['+phase+'] Epoch %d'% (epoch+1))
      for inputs, labels in pbar:
        inputs = [inputs[0].to('cuda') , inputs[1].to('cuda')]
        labels = labels.to('cuda')
        # optimizerの初期化
        optimizer.zero_grad()
        
        # 学習時のみ勾配を計算させる設定にする
        with torch.set_grad_enabled(phase == 'train'):
          outputs = model(inputs)
          
          # 損失を計算
          loss = criterion(outputs, labels)
          
          # ラベルを予測
          _, preds = torch.max(outputs, 1)
          
          # 訓練時はバックプロパゲーション
          if phase == 'train':
            # 逆伝搬の計算
            loss.backward()
            # パラメータの更新
            optimizer.step()
          
          # イテレーション結果の計算
          # lossの合計を更新
          # PyTorchの仕様上各バッチ内での平均のlossが計算される。
          # データ数を掛けることで平均から合計に変換をしている。
          # 損失和は「全データの損失/データ数」で計算されるため、
          # 平均のままだと損失和を求めることができないため。
          l = loss.item()
          epoch_loss += l * inputs[0].size(0)
          pbar.set_postfix(dict(loss=l))

          # 正解数の合計を更新
          epoch_corrects += torch.sum(preds == labels.data)

    # epochごとのlossと正解率を表示
    epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
    epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
    loss_hist[0].append(epoch_loss) if phase == 'train' else loss_hist[1].append(epoch_loss)
    acc_hist[0].append(epoch_acc)      if phase == 'train' else acc_hist[1].append(epoch_acc)

    print('[{}] Epoch {} Result: Loss: {:.4f} Acc: {:.4f}\n'.format(phase, epoch+1, epoch_loss, epoch_acc))

Epoch 1/200
-------------


[train] Epoch 1: 100%|[32m██████████[0m| 463/463 [04:20<00:00,  1.78batch/s, loss=1.22]


[train] Epoch 1 Result: Loss: 1.3251 Acc: 0.4074



[valid] Epoch 1: 100%|[31m██████████[0m| 113/113 [00:32<00:00,  3.53batch/s, loss=1.38]


[valid] Epoch 1 Result: Loss: 1.2220 Acc: 0.4700

Epoch 2/200
-------------


[train] Epoch 2: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=1.2]


[train] Epoch 2 Result: Loss: 1.1778 Acc: 0.5423



[valid] Epoch 2: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.76batch/s, loss=1.22]


[valid] Epoch 2 Result: Loss: 1.1478 Acc: 0.5782

Epoch 3/200
-------------


[train] Epoch 3: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=1.07]


[train] Epoch 3 Result: Loss: 1.1238 Acc: 0.6111



[valid] Epoch 3: 100%|[31m██████████[0m| 113/113 [00:31<00:00,  3.54batch/s, loss=1.34]


[valid] Epoch 3 Result: Loss: 1.1076 Acc: 0.6262

Epoch 4/200
-------------


[train] Epoch 4: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=1.11]


[train] Epoch 4 Result: Loss: 1.0827 Acc: 0.6597



[valid] Epoch 4: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.67batch/s, loss=1.08]


[valid] Epoch 4 Result: Loss: 1.0716 Acc: 0.6674

Epoch 5/200
-------------


[train] Epoch 5: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=1.04]


[train] Epoch 5 Result: Loss: 1.0522 Acc: 0.6920



[valid] Epoch 5: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.67batch/s, loss=1.13]


[valid] Epoch 5 Result: Loss: 1.0347 Acc: 0.7061

Epoch 6/200
-------------


[train] Epoch 6: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=1.08]


[train] Epoch 6 Result: Loss: 1.0236 Acc: 0.7208



[valid] Epoch 6: 100%|[31m██████████[0m| 113/113 [00:32<00:00,  3.49batch/s, loss=1.11]


[valid] Epoch 6 Result: Loss: 1.0121 Acc: 0.7312

Epoch 7/200
-------------


[train] Epoch 7: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.899]


[train] Epoch 7 Result: Loss: 1.0011 Acc: 0.7430



[valid] Epoch 7: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.75batch/s, loss=1.05]


[valid] Epoch 7 Result: Loss: 0.9903 Acc: 0.7538

Epoch 8/200
-------------


[train] Epoch 8: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=1.03]


[train] Epoch 8 Result: Loss: 0.9796 Acc: 0.7657



[valid] Epoch 8: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.70batch/s, loss=0.827]


[valid] Epoch 8 Result: Loss: 0.9782 Acc: 0.7622

Epoch 9/200
-------------


[train] Epoch 9: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.963]


[train] Epoch 9 Result: Loss: 0.9607 Acc: 0.7836



[valid] Epoch 9: 100%|[31m██████████[0m| 113/113 [00:31<00:00,  3.59batch/s, loss=0.911]


[valid] Epoch 9 Result: Loss: 0.9960 Acc: 0.7435

Epoch 10/200
-------------


[train] Epoch 10: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.948]


[train] Epoch 10 Result: Loss: 0.9430 Acc: 0.8015



[valid] Epoch 10: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.73batch/s, loss=0.865]


[valid] Epoch 10 Result: Loss: 0.9382 Acc: 0.8044

Epoch 11/200
-------------


[train] Epoch 11: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.85]


[train] Epoch 11 Result: Loss: 0.9315 Acc: 0.8126



[valid] Epoch 11: 100%|[31m██████████[0m| 113/113 [00:29<00:00,  3.81batch/s, loss=0.851]


[valid] Epoch 11 Result: Loss: 0.9676 Acc: 0.7757

Epoch 12/200
-------------


[train] Epoch 12: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.83]


[train] Epoch 12 Result: Loss: 0.9179 Acc: 0.8258



[valid] Epoch 12: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.74batch/s, loss=0.84]


[valid] Epoch 12 Result: Loss: 0.9346 Acc: 0.8087

Epoch 13/200
-------------


[train] Epoch 13: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.86]


[train] Epoch 13 Result: Loss: 0.9094 Acc: 0.8349



[valid] Epoch 13: 100%|[31m██████████[0m| 113/113 [00:32<00:00,  3.52batch/s, loss=0.841]


[valid] Epoch 13 Result: Loss: 0.9103 Acc: 0.8328

Epoch 14/200
-------------


[train] Epoch 14: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.869]


[train] Epoch 14 Result: Loss: 0.8964 Acc: 0.8477



[valid] Epoch 14: 100%|[31m██████████[0m| 113/113 [00:29<00:00,  3.77batch/s, loss=0.808]


[valid] Epoch 14 Result: Loss: 0.9059 Acc: 0.8374

Epoch 15/200
-------------


[train] Epoch 15: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.901]


[train] Epoch 15 Result: Loss: 0.8865 Acc: 0.8578



[valid] Epoch 15: 100%|[31m██████████[0m| 113/113 [00:32<00:00,  3.53batch/s, loss=0.856]


[valid] Epoch 15 Result: Loss: 0.8921 Acc: 0.8522

Epoch 16/200
-------------


[train] Epoch 16: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.925]


[train] Epoch 16 Result: Loss: 0.8770 Acc: 0.8668



[valid] Epoch 16: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.68batch/s, loss=0.814]


[valid] Epoch 16 Result: Loss: 0.8604 Acc: 0.8842

Epoch 17/200
-------------


[train] Epoch 17: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.809]


[train] Epoch 17 Result: Loss: 0.8697 Acc: 0.8754



[valid] Epoch 17: 100%|[31m██████████[0m| 113/113 [00:31<00:00,  3.63batch/s, loss=0.807]


[valid] Epoch 17 Result: Loss: 0.8525 Acc: 0.8912

Epoch 18/200
-------------


[train] Epoch 18: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.807]


[train] Epoch 18 Result: Loss: 0.8601 Acc: 0.8846



[valid] Epoch 18: 100%|[31m██████████[0m| 113/113 [00:31<00:00,  3.63batch/s, loss=1.26]


[valid] Epoch 18 Result: Loss: 0.9466 Acc: 0.7939

Epoch 19/200
-------------


[train] Epoch 19: 100%|[32m██████████[0m| 463/463 [04:15<00:00,  1.81batch/s, loss=0.874]


[train] Epoch 19 Result: Loss: 0.8541 Acc: 0.8907



[valid] Epoch 19: 100%|[31m██████████[0m| 113/113 [00:31<00:00,  3.61batch/s, loss=0.806]


[valid] Epoch 19 Result: Loss: 0.8912 Acc: 0.8521

Epoch 20/200
-------------


[train] Epoch 20: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.873]


[train] Epoch 20 Result: Loss: 0.8484 Acc: 0.8960



[valid] Epoch 20: 100%|[31m██████████[0m| 113/113 [00:31<00:00,  3.59batch/s, loss=0.807]


[valid] Epoch 20 Result: Loss: 0.8487 Acc: 0.8965

Epoch 21/200
-------------


[train] Epoch 21: 100%|[32m██████████[0m| 463/463 [04:14<00:00,  1.82batch/s, loss=0.866]


[train] Epoch 21 Result: Loss: 0.8431 Acc: 0.9015



[valid] Epoch 21: 100%|[31m██████████[0m| 113/113 [00:30<00:00,  3.74batch/s, loss=0.832]


[valid] Epoch 21 Result: Loss: 0.8372 Acc: 0.9062

Epoch 22/200
-------------


[train] Epoch 22:  89%|[32m████████▉ [0m| 414/463 [03:48<00:26,  1.83batch/s, loss=0.838]

# Visualization

In [None]:
from matplotlib import pyplot as plt
# 学習状況のプロット関数の定義
# 分類精度の履歴をプロット
plt.plot([a.cpu() for a in acc_hist[0]])
plt.plot([a.cpu() for a in acc_hist[1]])
plt.title('model accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(['training', 'validation'], loc='lower right')
plt.show()

# 損失関数の履歴をプロット
plt.plot(loss_hist[0])
plt.plot(loss_hist[1])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['training', 'validation'], loc='upper right')
plt.show()


# Save the model

In [None]:
model_path = RESEARCH_WORK_PATH + 'models/multimodal_5s_0.5shift_200epoch_fixed_kernel_3.pth'
torch.save(model.state_dict(), model_path)