In [1]:
import os
import glob
import sys

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Absolute path of the parent of the current working directory; the project
# packages imported in the next cell are expected to live there.
ROOT = os.path.abspath('../')

# Register the project root for imports exactly once.
# NOTE(review): the path is appended (lowest priority), so an installed
# package with the same top-level name (e.g. `datasets`) would be imported
# instead of the project's own -- confirm this is not an issue in the
# target environment.
if sys.path.count(ROOT) == 0:
    sys.path.append(ROOT)

In [3]:
from datasets.wafer import WM811K
from datasets.wafer import get_dataloader
from tasks.classification import Classification
from baselines.wm811k.transforms import CNNWDITransform
from baselines.wm811k.models import CNNWDI
from models.head import GAPClassifier
from utils.metrics import MultiF1Score
from utils.logging import get_logger

In [4]:
# Square size handed to CNNWDITransform for both the train and test pipelines.
# Presumably (height, width) in pixels -- TODO confirm against the transform.
_SIDE = 96
SIZE = (_SIDE, _SIDE)

In [5]:
# Build the two preprocessing pipelines in one pass; they share SIZE and
# differ only in the `mode` flag given to CNNWDITransform.
train_transform, test_transform = (
    CNNWDITransform(SIZE, mode=split) for split in ('train', 'test')
)

In [6]:
# Device string forwarded to `Classification.run(device=...)`.
# The experiment was originally pinned to the third GPU ('cuda:2'). That index
# only exists when at least three CUDA devices are visible, so fall back to the
# first GPU, and then to the CPU, instead of failing on smaller machines.
_PREFERRED_GPU = 2
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_GPU:
    DEVICE = f'cuda:{_PREFERRED_GPU}'
elif torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

In [7]:
# Training hyper-parameters forwarded to `Classification.run(...)`.
EPOCHS = 100       # passed as `epochs`
BATCH_SIZE = 256   # passed as `batch_size`
NUM_WORKERS = 4    # passed as `num_workers`

In [8]:
# ---------------------------------------------------------------------------
# Label-proportion sweep: train and evaluate a CNN-WDI baseline once per
# proportion value, each run writing to its own checkpoint directory.
# ---------------------------------------------------------------------------

# Values passed as `proportion=` to the training WM811K dataset.
LABEL_PROPORTIONS = [0.01, 0.05, 0.10, 0.25, 0.50, 1.00]
# Shared by the classifier head and the F1 metric so the two cannot drift apart.
NUM_CLASSES = 9
LEARNING_RATE = 0.001
# Re-applied before every run so each proportion starts from the same torch
# RNG state.
# NOTE(review): whether WM811K's `proportion` subsampling draws from torch's
# RNG is not visible here; if it uses `random`/`numpy`, seed those as well.
SEED = 0

# The validation and test sets do not depend on the label proportion, so they
# are built once here rather than being rebuilt on every iteration.
valid_set = WM811K(
    root=os.path.join(ROOT, "data/wm811k/labeled/valid/"),
    transform=test_transform
)
test_set = WM811K(
    root=os.path.join(ROOT, "data/wm811k/labeled/test/"),
    transform=test_transform
)

for p in LABEL_PROPORTIONS:

    print(f"== Label Proportion: {p:.2f}==")

    torch.manual_seed(SEED)

    train_set = WM811K(
        root=os.path.join(ROOT, "data/wm811k/labeled/train/"),
        transform=train_transform,
        proportion=p,
    )

    # Anchor the checkpoint directory on ROOT, like the data paths above,
    # instead of a '../' path that depends on the working directory at run
    # time. With ROOT == abspath('../') this is the same location as before.
    ckpt_dir = os.path.join(ROOT, f'checkpoints/wm811k/baselines/cnnwdi/LP-{p:.2f}')

    # A fresh model and optimizer for every proportion; both the backbone and
    # the classification head are optimized jointly.
    backbone = CNNWDI(in_channels=2)
    classifier = GAPClassifier(backbone.out_channels, num_classes=NUM_CLASSES)
    optimizer = optim.Adam(
        [*backbone.parameters(), *classifier.parameters()],
        lr=LEARNING_RATE,
    )
    cnnwdi = Classification(
        backbone=backbone,
        classifier=classifier,
        optimizer=optimizer,
        scheduler=None,
        loss_function=nn.CrossEntropyLoss(reduction='mean'),
        metrics=dict(f1=MultiF1Score(num_classes=NUM_CLASSES, average='macro')),
        checkpoint_dir=ckpt_dir,
        write_summary=False,
    )

    # Make sure the directory exists before a log file is opened inside it.
    # Done after constructing `Classification` (which may create it itself)
    # and with exist_ok=True, so this is a no-op when it is already there.
    os.makedirs(ckpt_dir, exist_ok=True)

    cnnwdi.run(
        train_set=train_set,
        valid_set=valid_set,
        test_set=test_set,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        device=DEVICE,
        logger=get_logger(stream=False, logfile=os.path.join(ckpt_dir, 'main.log')),
        eval_metric='f1',
        balance=True,
    )

== Label Proportion: 0.01==
 Epoch: [100/100] ( 56) | train_loss: 0.0653 | valid_loss: 0.3332 | train_f1: 0.9804 | valid_f1: 0.6272 |: 100%|[94m██████████[39m| [05:00<00:00,  3.00s/it]
 Best model (  56):  train_loss: 0.1221 | valid_loss: 0.2557 | test_loss: 0.2648 | train_f1: 0.9600 | valid_f1: 0.6657 | test_f1: 0.6536 |
 Last model ( 100):  train_loss: 0.0653 | valid_loss: 0.3332 | test_loss: 0.3450 | train_f1: 0.9804 | valid_f1: 0.6272 | test_f1: 0.6301 |
== Label Proportion: 0.05==
 Epoch: [100/100] ( 94) | train_loss: 0.1770 | valid_loss: 0.1563 | train_f1: 0.9482 | valid_f1: 0.7807 |: 100%|[94m██████████[39m| [06:13<00:00,  3.74s/it]
 Best model (  94):  train_loss: 0.1499 | valid_loss: 0.1532 | test_loss: 0.1491 | train_f1: 0.9520 | valid_f1: 0.8026 | test_f1: 0.7996 |
 Last model ( 100):  train_loss: 0.1770 | valid_loss: 0.1563 | test_loss: 0.1528 | train_f1: 0.9482 | valid_f1: 0.7807 | test_f1: 0.7900 |
== Label Proportion: 0.10==
 Epoch: [100/100] ( 31) | train_loss: 0.11