In [1]:
import importlib
import data_utils
importlib.reload(data_utils)
from data_utils import PhonemeDataset
from mlp_mixer import MLPMixer

import numpy as np
import torch 
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as transforms

import pytorch_lightning as pl
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor

torch.autograd.set_detect_anomaly(True)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x14ae453848f0>

# Load data

In [6]:
transform = None
batch_size = 64

train_loader = DataLoader(
    PhonemeDataset(
        data_filename='../Data/Phoneme/train_X.npy',
        label_filename='../Data/Phoneme/train_y.npy',
        transform=transform
    ), 
    batch_size=batch_size, 
    shuffle=True
)

val_loader = DataLoader(
    PhonemeDataset(
        data_filename='../Data/Phoneme/valid_X.npy',
        label_filename='../Data/Phoneme/valid_y.npy',
        transform=None
    ), 
    batch_size=batch_size, 
    shuffle=False
)

test_loader = DataLoader(
    PhonemeDataset(
        data_filename='../Data/Phoneme/test_X.npy',
        label_filename='../Data/Phoneme/test_y.npy',
        transform=None
    ), 
    batch_size=batch_size, 
    shuffle=False
)

# Check a batch of train data
for X, y in train_loader:
    print(f"Shape of X (batch, channels, timesteps): {X.shape}, shape of labels: {len(y)}")
    break

Shape of X (batch, channels, timesteps): torch.Size([64, 11, 220]), shape of labels: 64


# Create mlp-mixer model

In [7]:
# patch_class options are: "sequential1d", "random1d", "cyclical1d"
patch_class = "sequential1d"

mixer = MLPMixer(
    num_classes=39, 
    num_blocks=5, 
    patch_size=10, 
    hidden_dim=128, 
    patch_class='cyclical1d', 
    tokens_mlp_dim=64, 
    channels_mlp_dim=32, 
    padded_length=220,
    p_dropout=0.5,
    lr=3e-3
)

In [8]:
callbacks = [EarlyStopping(monitor="val_loss", patience=10, mode="min")]

mixer_trainer = Trainer(max_epochs=100)

mixer_trainer.fit(
    model=mixer, 
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    callbacks=callbacks
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type             | Params
--------------------------------------------------
0 | loss         | CrossEntropyLoss | 0     
1 | patching     | PatchingClass    | 14.2 K
2 | mixer_blocks | ModuleList       | 92.8 K
3 | layer_norm   | LayerNorm        | 256   
4 | mlp_head     | Sequential       | 5.0 K 
5 | dropout      | Dropout          | 0     
--------------------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.449     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 99: 100%|██████████| 52/52 [00:07<00:00,  7.19it/s, v_num=25, train_loss_step=0.0697, val_loss=14.10, val_acc=0.104, collapse_flg_val=2.48e+3, train_loss_epoch=0.0663, train_acc=0.980, collapse_flg_train=2.49e+3]  

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 52/52 [00:07<00:00,  7.14it/s, v_num=25, train_loss_step=0.0697, val_loss=14.10, val_acc=0.104, collapse_flg_val=2.48e+3, train_loss_epoch=0.0663, train_acc=0.980, collapse_flg_train=2.49e+3]


In [10]:
# Test
mixer_trainer.test(
    dataloaders=test_loader
)

Restoring states from the checkpoint path at /mnt/lustre/koa/koastore/sadow_group/shared/EE645/mlp-mixer-1d-classification/lightning_logs/version_25/checkpoints/epoch=99-step=5200.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /mnt/lustre/koa/koastore/sadow_group/shared/EE645/mlp-mixer-1d-classification/lightning_logs/version_25/checkpoints/epoch=99-step=5200.ckpt
SLURM auto-requeueing enabled. Setting signal handlers.


Testing DataLoader 0:   0%|          | 0/27 [00:00<?, ?it/s]

Testing DataLoader 0: 100%|██████████| 27/27 [00:02<00:00, 12.62it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.10023866593837738
        test_loss           14.173046112060547
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 14.173046112060547, 'test_acc': 0.10023866593837738}]