In [1]:
import math
import torch
from torch import nn

# Model

In [2]:
class Swish(nn.Module):
  def forward(self, x):
    return x * torch.sigmoid(x)

In [3]:
class FeatureEx1d(nn.Module):
  def __init__(self, input_size):
    super().__init__()
    self.layers = nn.Sequential(
        nn.linear(input_size, 256),
        nn.functional.relu(),
        nn.linear(256, 128),
        nn.functional.relu(),
        nn.linear(128, 64),
        nn.functional.relu()
    )

  def forward(self, x):
    return self.layers(x)

In [4]:
class SEblock(nn.Module):
  def __init__(self, ch_in, ch_sq):
    super().__init__()
    self.se = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(ch_in, ch_sq, 1),
        Swish(),
        nn.Conv2d(ch_sq, ch_in, 1),
    )
    self.se.apply(weight_init)

  def forward(self, x):
    return x * torch.sigmoid(self.se(x))

def weight_init(m):
  if isinstance(m, nn.Conv2d):
    nn.init.kaiming_normal_(m.weight)

  if isinstance(m, nn.Linear):
    nn.init.kaiming_uniform_(m.weight)
    nn.init.zeros_(m.bias)

In [5]:
class ConvBN(nn.Module):
  def __init__(self, ch_in, ch_out, kernel_size, stride=1, padding=0, groups=1):
    super().__init__()
    self.layers=nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(ch_out),
    )
    self.layers.apply(weight_init)

  def forward(self, x):
    return self.layers(x)

In [6]:
class DropConnect(nn.Module):
  def __init__(self, drop_rate):
    super().__init__()
    self.drop_rate=drop_rate

  def forward(self, x):
    if self.training:
      keep_rate=1.0-self.drop_rate
      r = torch.rand([x.size(0),1,1,1], dtype=x.dtype).to(x.device)
      r+=keep_rate
      mask=r.floor()
      return x.div(keep_rate)*mask
    else:
      return x

In [7]:
class BMConvBlock(nn.Module):
  def __init__(self,ch_in,ch_out,expand_ratio,stride,kernel_size,reduction_ratio=4,drop_connect_rate=0.2):
    super().__init__()
    self.use_residual = (ch_in==ch_out) & (stride==1)
    ch_med = int(ch_in*expand_ratio)
    ch_sq  = max(1, ch_in//reduction_ratio)

    if expand_ratio != 1.0:
      layers = [ConvBN(ch_in, ch_med, 1), Swish()]
    else:
      layers = []
    
    layers.extend([ConvBN(ch_med, ch_med, kernel_size, stride=stride, padding=(kernel_size)//2, groups=ch_med),
                  Swish(),
                  SEblock(ch_med, ch_sq),
                  ConvBN(ch_med, ch_out, 1),
                  ])
    
    if self.use_residual:
      self.drop_connect = DropConnect(drop_connect_rate)

    self.layers = nn.Sequential(*layers)

  def forward(self, x):
    if self.use_residual:
      return x + self.drop_connect(self.layers(x))
    else:
      return self.layers(x)

In [8]:
class Flatten(nn.Module):
  def forward(self, x):
    return x.view(x.shape[0], -1)

In [9]:
class MultiModalNet(nn.Module):
  def __init__(self,
               width_mult=1.0,
               depth_mult=1.0,
               resolution=None,
               dropout_rate=0.2,
               num_1d_features=10,
               num_classes=4,
               input_ch=3):
    super().__init__()

    # expand_ratio, channel, repeats, stride, kernel_size                   
    settings = [
        [1,  16, 1, 1, 3],  # MBConv1_3x3, SE, 112 -> 112                   
        [6,  24, 2, 2, 3],  # MBConv6_3x3, SE, 112 ->  56                   
        [6,  40, 2, 2, 5],  # MBConv6_5x5, SE,  56 ->  28                   
        [6,  80, 3, 2, 3],  # MBConv6_3x3, SE,  28 ->  14                   
        [6, 112, 3, 1, 5],  # MBConv6_5x5, SE,  14 ->  14                   
        [6, 192, 4, 2, 5],  # MBConv6_5x5, SE,  14 ->   7                   
        [6, 320, 1, 1, 3]   # MBConv6_3x3, SE,   7 ->   7]                  
    ]
    ch_out = int(math.ceil(32*width_mult))
    features = [nn.AdaptiveAvgPool2d(resolution)] if resolution else []
    features.extend([ConvBN(input_ch, ch_out, 3, stride=2), Swish()])     # -> [32, 3, 3, 3]

    ch_in = ch_out
    for t, c, n, s, k in settings:
      ch_out = int(math.ceil(c*width_mult))
      repeats = int(math.ceil(n*depth_mult))
      for i in range(repeats):
        stride = s if i==0 else 1
        features.extend([BMConvBlock(ch_in, ch_out, t, stride, k)])
        ch_in = ch_out

    ch_last = int(math.ceil(1280*width_mult))
    features.extend([ConvBN(ch_in, ch_last, 1), Swish()])

    self.features2d = nn.Sequential(*features)
    self.reshape = nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        Flatten(),
        nn.Dropout(dropout_rate),
        nn.Linear(ch_last, 128)
    )
    self.classifier = nn.Sequential(
        nn.Linear(128*2, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, num_classes),
        nn.Softmax()
    )

    self.features1d = nn.Sequential(
        nn.Linear(num_1d_features, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
    )

  def forward(self, x):
    x1=self.features2d(x[1])
    x1=self.reshape(x1)
    x0=self.features1d(x[0])
    x = torch.cat((x0,x1), dim=1)
    x=self.classifier(x)

    return x

# Dataset

In [10]:
RESEARCH_WORK_PATH = "/content/drive/MyDrive/Colab Notebooks/BachelorResearch/"

In [11]:
import shutil
import os


data_dir = "/content/datas/"
if os.path.exists(data_dir):
  shutil.rmtree(data_dir)
for dirs in os.listdir(RESEARCH_WORK_PATH + "MER_audio_taffc_dataset_wav/5s_0.5shift/pickles1/"):
  if "Q" in dirs:
    shutil.copytree(RESEARCH_WORK_PATH + "MER_audio_taffc_dataset_wav/5s_0.5shift/pickles1/" + dirs, data_dir + dirs)

for q in "Q1 Q2 Q3 Q4".split(" "):
  d = RESEARCH_WORK_PATH + "MER_audio_taffc_dataset_wav/5s_0.5shift/pickles1/"  + q + "/"
  print(q, sum(os.path.isfile(os.path.join(d, name)) for name in os.listdir(d)))

for q in "Q1 Q2 Q3 Q4".split(" "):
  d = data_dir + q + "/"
  print(q, sum(os.path.isfile(os.path.join(d, name)) for name in os.listdir(d)))

Q1 9200
Q2 9200
Q3 9200
Q4 9200
Q1 9200
Q2 9200
Q3 9200
Q4 9200


In [12]:
import os
def make_filepath_list(root, train_rate=0.8):
  train_file_list = []
  valid_file_list = []

  for dirs in os.listdir(root):
    if "Q" in dirs:
      file_dir = os.path.join(root, dirs)
      file_list = os.listdir(file_dir)


      for f in file_list:
        split_num = int(f.split(".")[1].split("_")[-1])
        if split_num % 5 == 4:
          valid_file_list.append(os.path.join(root, dirs, f).replace('\\', '/'))
        else:
          train_file_list.append(os.path.join(root, dirs, f).replace('\\', '/'))

  
  return train_file_list, valid_file_list


In [13]:
from torch.utils import data
import numpy as np
import librosa
import pickle
import torch
import cv2

class musicDataset(data.Dataset):

  def __init__(self, file_list, classes, phase='train'):
    self.file_list = file_list
    self.classes = classes
    self.phase = phase

  def __len__(self):
    return len(self.file_list)


  def __getitem__(self, index):
    pickle_path = self.file_list[index]
    x,y = pickle.load(open(pickle_path, mode="rb"))
    return x, y


train_file_list, valid_file_list = make_filepath_list(data_dir)

print('学習データ数 : ', len(train_file_list))
print(train_file_list[0])

print('検証データ数 : ', len(valid_file_list))
print(valid_file_list[0])

q_classes = "Q1 Q2 Q3 Q4".split(" ")

train_dataset = musicDataset(
    file_list=train_file_list, classes=q_classes, phase='train'
)

valid_dataset = musicDataset(
    file_list=valid_file_list, classes=q_classes, phase='valid'
)


index = 0
print("Dataset1 shape:", train_dataset.__getitem__(index)[0][0].size())
print("Dataset2 shape:", train_dataset.__getitem__(index)[0][1].size())
print("Dataset label:", train_dataset.__getitem__(index)[1])

学習データ数 :  29600
/content/datas/Q3/Q3.MT0003114552_20.wav.pickle
検証データ数 :  7200
/content/datas/Q3/Q3.MT0009169626_19.wav.pickle
Dataset1 shape: torch.Size([26])
Dataset2 shape: torch.Size([1, 431, 257])
Dataset label: 2


# DataLoader

In [14]:
# バッチサイズの指定
batch_size = 32

# DataLoaderを作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True)

valid_dataloader = data.DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, num_workers=os.cpu_count(), pin_memory=True)

# pin_mem and num_workers are acceleration technique -> https://qiita.com/sugulu_Ogawa_ISID/items/62f5f7adee083d96a587#11-num_workers

# 辞書にまとめる
dataloaders_dict = {
    'train': train_dataloader, 
    'valid': valid_dataloader
}

# 動作確認
# イテレータに変換
batch_iterator = iter(dataloaders_dict['train'])

# 1番目の要素を取り出す
inputs, labels = next(batch_iterator)

print(inputs[0].size())
print(inputs[1].size())
print(labels)


torch.Size([32, 26])
torch.Size([32, 1, 431, 257])
tensor([3, 2, 1, 1, 0, 3, 2, 2, 2, 2, 3, 2, 0, 3, 0, 2, 3, 2, 1, 1, 1, 0, 3, 1,
        2, 0, 2, 3, 2, 2, 2, 1])


# Optimizer, Criterion

In [15]:
torch_clearlizer = False #@param {type: "boolean"}

if torch_clearlizer:
  del train_dataset
  del valid_dataset
  del train_dataloader
  del valid_dataloader
  del labels
  del loss_hist
  del acc_hist
  del optimizer
  del criterion
  torch.cuda.empty_cache()

In [16]:
from torch import optim

model     = MultiModalNet(input_ch=1, num_classes=4, num_1d_features=train_dataset.__getitem__(index)[0][0].size(0)).to('cuda')
optimizer = optim.SGD(model.parameters(),lr=0.1) 
criterion = nn.CrossEntropyLoss()

SGD -> じわっとloss減ってく


# Training

In [None]:
from tqdm import tqdm
# エポック数
num_epochs = 40


loss_hist = [[],[]]
acc_hist = [[],[]]

for epoch in range(num_epochs):
  print('Epoch {}/{}'.format(epoch+1, num_epochs))
  print('-------------')
  
  for phase in ['train', 'valid']:
      
    if phase == 'train':
      # モデルを訓練モードに設定
      model.train()
    else:
      # モデルを推論モードに設定
      model.eval()
    
    # 損失和
    epoch_loss = 0.0
    # 正解数
    epoch_corrects = 0
    
    # DataLoaderからデータをバッチごとに取り出す
    with tqdm(dataloaders_dict[phase],unit='batch',colour='green' if phase == 'train' else 'red') as pbar:
      pbar.set_description('['+phase+'] Epoch %d'% (epoch+1))
      for inputs, labels in pbar:
        inputs = [inputs[0].to('cuda') , inputs[1].to('cuda')]
        labels = labels.to('cuda')
        # optimizerの初期化
        optimizer.zero_grad()
        
        # 学習時のみ勾配を計算させる設定にする
        with torch.set_grad_enabled(phase == 'train'):
          outputs = model(inputs)
          
          # 損失を計算
          loss = criterion(outputs, labels)
          
          # ラベルを予測
          _, preds = torch.max(outputs, 1)
          
          # 訓練時はバックプロパゲーション
          if phase == 'train':
            # 逆伝搬の計算
            loss.backward()
            # パラメータの更新
            optimizer.step()
          
          # イテレーション結果の計算
          # lossの合計を更新
          # PyTorchの仕様上各バッチ内での平均のlossが計算される。
          # データ数を掛けることで平均から合計に変換をしている。
          # 損失和は「全データの損失/データ数」で計算されるため、
          # 平均のままだと損失和を求めることができないため。
          l = loss.item()
          epoch_loss += l * inputs[0].size(0)
          pbar.set_postfix(dict(loss=l))

          # 正解数の合計を更新
          epoch_corrects += torch.sum(preds == labels.data)

    # epochごとのlossと正解率を表示
    epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
    epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
    loss_hist[0].append(epoch_loss) if phase == 'train' else loss_hist[1].append(epoch_loss)
    acc_hist[0].append(epoch_acc)      if phase == 'train' else acc_hist[1].append(epoch_acc)

    print('[{}] Epoch {} Result: Loss: {:.4f} Acc: {:.4f}\n'.format(phase, epoch+1, epoch_loss, epoch_acc))

Epoch 1/40
-------------


  input = module(input)
[train] Epoch 1: 100%|[32m██████████[0m| 925/925 [04:06<00:00,  3.75batch/s, loss=1.14]


[train] Epoch 1 Result: Loss: 1.2359 Acc: 0.4770



[valid] Epoch 1: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.58batch/s, loss=1.01]


[valid] Epoch 1 Result: Loss: 1.1481 Acc: 0.5747

Epoch 2/40
-------------


[train] Epoch 2: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=1.06]


[train] Epoch 2 Result: Loss: 1.1158 Acc: 0.6191



[valid] Epoch 2: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.69batch/s, loss=0.965]


[valid] Epoch 2 Result: Loss: 1.0787 Acc: 0.6608

Epoch 3/40
-------------


[train] Epoch 3: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.992]


[train] Epoch 3 Result: Loss: 1.0580 Acc: 0.6837



[valid] Epoch 3: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.52batch/s, loss=0.929]


[valid] Epoch 3 Result: Loss: 1.0302 Acc: 0.7132

Epoch 4/40
-------------


[train] Epoch 4: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.957]


[train] Epoch 4 Result: Loss: 1.0175 Acc: 0.7242



[valid] Epoch 4: 100%|[31m██████████[0m| 225/225 [00:30<00:00,  7.39batch/s, loss=0.939]


[valid] Epoch 4 Result: Loss: 0.9894 Acc: 0.7550

Epoch 5/40
-------------


[train] Epoch 5: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=1.03]


[train] Epoch 5 Result: Loss: 0.9861 Acc: 0.7564



[valid] Epoch 5: 100%|[31m██████████[0m| 225/225 [00:30<00:00,  7.47batch/s, loss=0.938]


[valid] Epoch 5 Result: Loss: 0.9856 Acc: 0.7560

Epoch 6/40
-------------


[train] Epoch 6: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=1.09]


[train] Epoch 6 Result: Loss: 0.9604 Acc: 0.7824



[valid] Epoch 6: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.75batch/s, loss=0.889]


[valid] Epoch 6 Result: Loss: 0.9826 Acc: 0.7589

Epoch 7/40
-------------


[train] Epoch 7: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.962]


[train] Epoch 7 Result: Loss: 0.9364 Acc: 0.8061



[valid] Epoch 7: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.66batch/s, loss=0.868]


[valid] Epoch 7 Result: Loss: 0.9338 Acc: 0.8071

Epoch 8/40
-------------


[train] Epoch 8: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.838]


[train] Epoch 8 Result: Loss: 0.9206 Acc: 0.8226



[valid] Epoch 8: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.68batch/s, loss=0.846]


[valid] Epoch 8 Result: Loss: 0.8945 Acc: 0.8499

Epoch 9/40
-------------


[train] Epoch 9: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.873]


[train] Epoch 9 Result: Loss: 0.9036 Acc: 0.8394



[valid] Epoch 9: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.78batch/s, loss=0.866]


[valid] Epoch 9 Result: Loss: 0.8892 Acc: 0.8558

Epoch 10/40
-------------


[train] Epoch 10: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.894]


[train] Epoch 10 Result: Loss: 0.8863 Acc: 0.8568



[valid] Epoch 10: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.51batch/s, loss=0.876]


[valid] Epoch 10 Result: Loss: 0.8703 Acc: 0.8728

Epoch 11/40
-------------


[train] Epoch 11: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.842]


[train] Epoch 11 Result: Loss: 0.8738 Acc: 0.8697



[valid] Epoch 11: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.60batch/s, loss=0.867]


[valid] Epoch 11 Result: Loss: 0.8500 Acc: 0.8938

Epoch 12/40
-------------


[train] Epoch 12: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.87]


[train] Epoch 12 Result: Loss: 0.8610 Acc: 0.8827



[valid] Epoch 12: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.57batch/s, loss=0.844]


[valid] Epoch 12 Result: Loss: 0.8550 Acc: 0.8885

Epoch 13/40
-------------


[train] Epoch 13: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.781]


[train] Epoch 13 Result: Loss: 0.8500 Acc: 0.8940



[valid] Epoch 13: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.82batch/s, loss=0.84]


[valid] Epoch 13 Result: Loss: 0.8430 Acc: 0.8997

Epoch 14/40
-------------


[train] Epoch 14: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.85]


[train] Epoch 14 Result: Loss: 0.8427 Acc: 0.9013



[valid] Epoch 14: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.76batch/s, loss=0.863]


[valid] Epoch 14 Result: Loss: 0.8433 Acc: 0.9004

Epoch 15/40
-------------


[train] Epoch 15: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.829]


[train] Epoch 15 Result: Loss: 0.8385 Acc: 0.9046



[valid] Epoch 15: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.76batch/s, loss=0.887]


[valid] Epoch 15 Result: Loss: 0.8394 Acc: 0.9046

Epoch 16/40
-------------


[train] Epoch 16: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.816]


[train] Epoch 16 Result: Loss: 0.8340 Acc: 0.9092



[valid] Epoch 16: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.63batch/s, loss=0.843]


[valid] Epoch 16 Result: Loss: 0.8214 Acc: 0.9228

Epoch 17/40
-------------


[train] Epoch 17: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.91]


[train] Epoch 17 Result: Loss: 0.8248 Acc: 0.9182



[valid] Epoch 17: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.78batch/s, loss=0.864]


[valid] Epoch 17 Result: Loss: 0.8179 Acc: 0.9258

Epoch 18/40
-------------


[train] Epoch 18: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.846]


[train] Epoch 18 Result: Loss: 0.8191 Acc: 0.9247



[valid] Epoch 18: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.87batch/s, loss=0.833]


[valid] Epoch 18 Result: Loss: 0.8094 Acc: 0.9353

Epoch 19/40
-------------


[train] Epoch 19: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.782]


[train] Epoch 19 Result: Loss: 0.8157 Acc: 0.9284



[valid] Epoch 19: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.79batch/s, loss=0.837]


[valid] Epoch 19 Result: Loss: 0.8177 Acc: 0.9263

Epoch 20/40
-------------


[train] Epoch 20: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.877]


[train] Epoch 20 Result: Loss: 0.8123 Acc: 0.9313



[valid] Epoch 20: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.78batch/s, loss=0.837]


[valid] Epoch 20 Result: Loss: 0.8035 Acc: 0.9400

Epoch 21/40
-------------


[train] Epoch 21: 100%|[32m██████████[0m| 925/925 [04:06<00:00,  3.76batch/s, loss=0.744]


[train] Epoch 21 Result: Loss: 0.8074 Acc: 0.9362



[valid] Epoch 21: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.61batch/s, loss=0.839]


[valid] Epoch 21 Result: Loss: 0.8068 Acc: 0.9364

Epoch 22/40
-------------


[train] Epoch 22: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.773]


[train] Epoch 22 Result: Loss: 0.8011 Acc: 0.9425



[valid] Epoch 22: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.90batch/s, loss=0.809]


[valid] Epoch 22 Result: Loss: 0.7964 Acc: 0.9476

Epoch 23/40
-------------


[train] Epoch 23: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.808]


[train] Epoch 23 Result: Loss: 0.8010 Acc: 0.9424



[valid] Epoch 23: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.54batch/s, loss=0.81]


[valid] Epoch 23 Result: Loss: 0.7950 Acc: 0.9489

Epoch 24/40
-------------


[train] Epoch 24: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.798]


[train] Epoch 24 Result: Loss: 0.7971 Acc: 0.9463



[valid] Epoch 24: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.89batch/s, loss=0.837]


[valid] Epoch 24 Result: Loss: 0.7918 Acc: 0.9515

Epoch 25/40
-------------


[train] Epoch 25: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.855]


[train] Epoch 25 Result: Loss: 0.7935 Acc: 0.9504



[valid] Epoch 25: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.51batch/s, loss=0.835]


[valid] Epoch 25 Result: Loss: 0.7952 Acc: 0.9487

Epoch 26/40
-------------


[train] Epoch 26: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.775]


[train] Epoch 26 Result: Loss: 0.7948 Acc: 0.9485



[valid] Epoch 26: 100%|[31m██████████[0m| 225/225 [00:30<00:00,  7.41batch/s, loss=0.927]


[valid] Epoch 26 Result: Loss: 0.8038 Acc: 0.9403

Epoch 27/40
-------------


[train] Epoch 27: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.76batch/s, loss=0.751]


[train] Epoch 27 Result: Loss: 0.7899 Acc: 0.9540



[valid] Epoch 27: 100%|[31m██████████[0m| 225/225 [00:30<00:00,  7.37batch/s, loss=0.807]


[valid] Epoch 27 Result: Loss: 0.7869 Acc: 0.9574

Epoch 28/40
-------------


[train] Epoch 28: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.745]


[train] Epoch 28 Result: Loss: 0.7880 Acc: 0.9555



[valid] Epoch 28: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.76batch/s, loss=0.807]


[valid] Epoch 28 Result: Loss: 0.7835 Acc: 0.9606

Epoch 29/40
-------------


[train] Epoch 29: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.774]


[train] Epoch 29 Result: Loss: 0.7863 Acc: 0.9571



[valid] Epoch 29: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.78batch/s, loss=0.807]


[valid] Epoch 29 Result: Loss: 0.7797 Acc: 0.9643

Epoch 30/40
-------------


[train] Epoch 30: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.744]


[train] Epoch 30 Result: Loss: 0.7844 Acc: 0.9595



[valid] Epoch 30: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.55batch/s, loss=0.811]


[valid] Epoch 30 Result: Loss: 0.7845 Acc: 0.9587

Epoch 31/40
-------------


[train] Epoch 31: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.806]


[train] Epoch 31 Result: Loss: 0.7863 Acc: 0.9573



[valid] Epoch 31: 100%|[31m██████████[0m| 225/225 [00:29<00:00,  7.72batch/s, loss=0.807]


[valid] Epoch 31 Result: Loss: 0.7811 Acc: 0.9625

Epoch 32/40
-------------


[train] Epoch 32: 100%|[32m██████████[0m| 925/925 [04:05<00:00,  3.77batch/s, loss=0.744]


[train] Epoch 32 Result: Loss: 0.7831 Acc: 0.9607



[valid] Epoch 32: 100%|[31m██████████[0m| 225/225 [00:28<00:00,  7.89batch/s, loss=0.812]


[valid] Epoch 32 Result: Loss: 0.7829 Acc: 0.9601

Epoch 33/40
-------------


[train] Epoch 33:  58%|[32m█████▊    [0m| 538/925 [02:22<01:42,  3.78batch/s, loss=0.82]

# Visualization

In [None]:
from matplotlib import pyplot as plt
# 学習状況のプロット関数の定義
# 分類精度の履歴をプロット
plt.plot([a.cpu() for a in acc_hist[0]])
plt.plot([a.cpu() for a in acc_hist[1]])
plt.title('model accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(['training', 'validation'], loc='lower right')
plt.show()

# 損失関数の履歴をプロット
plt.plot(loss_hist[0])
plt.plot(loss_hist[1])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['training', 'validation'], loc='upper right')
plt.show()


# Save the model

In [None]:
model_path = RESEARCH_WORK_PATH + 'models/multimodal_5s_0.5shift_40epoch.pth'
torch.save(model.state_dict(), model_path)

# 疑問点
入力画素数はどこでわかる？