In [None]:
import os, sys
# Resolve the project root and the two module directories that live under it.
# The trailing './' is kept on purpose so every derived path string stays the same.
project_dir = os.path.join(os.getcwd(), './')
medmnist_dir = os.path.join(project_dir, 'modules/MedMNIST')
ipdl_dir = os.path.join(project_dir, 'modules/IPDL')

# Append each directory to the import search path, skipping entries that are
# already present so that re-running the cell does not add duplicates.
for _import_root in (project_dir, medmnist_dir, ipdl_dir):
    if _import_root not in sys.path:
        sys.path.append(_import_root)

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import torch

In [None]:
from experiment.classifier import PCAE

# Build the experiment object from the BREAST weights file stored under
# data/PCAE/weights/AE inside the project directory.
# NOTE(review): presumably these are pretrained autoencoder weights - confirm in PCAE.
breast_weights_path = os.path.join(project_dir, 'data/PCAE/weights/AE/BREAST.pt')
pcae_exp = PCAE(breast_weights_path)

# Dataset

In [None]:
import medmnist
from medmnist import INFO

# Which MedMNIST dataset to use, and the flag later passed as `download=`
# to the dataset constructor.
data_flag = 'breastmnist'
download = True

# Metadata record of the selected dataset and the fields read from it.
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])

# Look up the dataset class on the medmnist package by the name given in the metadata.
DataClass = getattr(medmnist, info['python_class'])

In [None]:
from torchvision.transforms import Compose, Resize, ToTensor

# Resize every image to 64x64, then convert it to a tensor.
data_transform = Compose([
    Resize((64, 64)),
    ToTensor(),
])

# `DataClass` and `download` are defined in the dataset-selection cell above.
#
# Three distinct splits are used. The evaluation set consumed during training
# comes from the dedicated 'val' split, so the 'test' split stays untouched by
# training-time evaluation and remains a genuinely held-out set.
train_dataset = DataClass(split='train', transform=data_transform, download=download)
eval_dataset = DataClass(split='val', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)

## Reduce dataset

In [None]:
from imblearn.under_sampling import RandomUnderSampler


def undersample_in_place(dataset, sampling_strategy, random_state=123):
    """Randomly undersample ``dataset`` in place to the requested class counts.

    Args:
        dataset: object exposing ``imgs`` and ``labels`` array attributes;
            both attributes are overwritten with the reduced arrays.
        sampling_strategy: mapping ``{class_label: n_samples_to_keep}`` handed
            to ``RandomUnderSampler``.
        random_state: seed for the sampler, so the selection is repeatable.
    """
    images = dataset.imgs
    # The sampler is fed one flat feature row per image and a 1-D label vector.
    flat_images = images.reshape((images.shape[0], -1))
    flat_labels = dataset.labels.flatten()

    sampler = RandomUnderSampler(sampling_strategy=sampling_strategy, random_state=random_state)
    kept_images, kept_labels = sampler.fit_resample(flat_images, flat_labels)

    # Restore the per-image shape from all trailing axes, so arrays with an
    # extra channel axis are handled as well as plain (N, H, W) arrays.
    dataset.imgs = kept_images.reshape((-1, *images.shape[1:]))
    # Labels are stored as the 1-D array returned by the sampler; the original
    # per-sample label axis is intentionally not restored.
    dataset.labels = kept_labels


# Per-class sample counts to keep, paired by position with `datasets`.
sampling_strategies = [{0: 128, 1: 128}, {0: 16, 1: 16}]
datasets = [train_dataset, eval_dataset]

for dataset, sampling_strategy in zip(datasets, sampling_strategies):
    undersample_in_place(dataset, sampling_strategy)

print('Train Dataset: {} samples'.format(len(train_dataset)))
print('Eval Dataset: {} samples'.format(len(eval_dataset)))

## Training

In [None]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

# Batches of 32; only the training loader shuffles, evaluation keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

# TensorBoard output goes to a log directory named after the experiment's model.
log_dir = 'logs/{}/CLASS'.format(pcae_exp.model_name)
tb_writer = SummaryWriter(log_dir)

# Run training for 800 epochs with both loaders and the writer.
pcae_exp.train(train_loader, eval_loader, tb_writer, n_epoch=800)

In [None]:
# Pull one batch from the training loader for manual inspection and split it
# into its two components.
first_batch = next(iter(train_loader))
a, b = first_batch

In [None]:
from torch import nn
# Stand-alone cross-entropy loss instance, constructed with default arguments,
# for trying the loss out by hand.
criterion = nn.CrossEntropyLoss()

In [None]:
# Hand-made check of the loss: two samples with two class scores each, where
# the higher score of each row sits at that row's target index.
# Note that this rebinds `a` and `b`.
a = torch.tensor([
    [0.1, 0.9],
    [0.9, 0.1],
])
b = torch.tensor([1, 0])

# Bare last expression so the notebook displays the resulting loss value.
criterion(a, b)