In [None]:
import os

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid

In [None]:
# Dataset root. Set the EMBRYO_DATA_DIR environment variable to point at another copy of the data;
# otherwise the relative default below is used.
ROOT_DIR = os.environ.get(
    "EMBRYO_DATA_DIR", "data/world-championship-2023-embryo-classification/hvwc23"
)

In [None]:
train_df = pd.read_csv(os.path.join(ROOT_DIR, "train.csv"))

# Every later cell indexes these two columns; fail fast with a readable message if the CSV differs.
assert {"Image", "Class"} <= set(train_df.columns), f"unexpected columns: {list(train_df.columns)}"

In [None]:
# Peek at the first five rows of the training metadata.
train_df.head(n=5)

In [None]:
# `Class` holds 0 ("not good") / 1 ("good"), as in the sampling cells below. A bar chart of counts
# labels each class directly instead of binning the two values into a 10-bin histogram.
ax = (
    train_df["Class"]
    .value_counts()
    .sort_index()
    .rename({0: "not good", 1: "good"})
    .plot.bar(rot=0)
)
ax.set(title="Class distribution", xlabel="Class", ylabel="Number of images")
plt.show()

In [None]:
# Load the test-split metadata from the dataset root.
test_csv_path = os.path.join(ROOT_DIR, "test.csv")
test_df = pd.read_csv(test_csv_path)

In [None]:
# Image names prefixed "D3_" are the day-3 images; every other row is treated as day 5.
is_D3_train = train_df["Image"].str.startswith("D3_")
train_D3_df = train_df[is_D3_train]
train_D5_df = train_df[~is_D3_train]

In [None]:
# Number of training rows per day (day 3, day 5).
len(train_D3_df), len(train_D5_df)

In [None]:
# Same day-3 / day-5 split for the test set, keyed on the "D3_" file-name prefix.
is_D3_test = test_df["Image"].str.startswith("D3_")
test_D3_df = test_df[is_D3_test]
test_D5_df = test_df[~is_D3_test]

In [None]:
# Number of test rows per day (day 3, day 5).
len(test_D3_df), len(test_D5_df)

In [None]:
# Bar chart of label counts for day-3 images (0 = "not good", 1 = "good", as in the sampling cells
# below). A bar chart avoids the misleading 10-bin histogram of a two-valued column.
ax = (
    train_D3_df["Class"]
    .value_counts()
    .sort_index()
    .rename({0: "not good", 1: "good"})
    .plot.bar(rot=0)
)
ax.set(title="Class distribution in day 3", xlabel="Class", ylabel="Number of images")
plt.show()

In [None]:
# Bar chart of label counts for day-5 images (0 = "not good", 1 = "good", as in the sampling cells
# below). A bar chart avoids the misleading 10-bin histogram of a two-valued column.
ax = (
    train_D5_df["Class"]
    .value_counts()
    .sort_index()
    .rename({0: "not good", 1: "good"})
    .plot.bar(rot=0)
)
ax.set(title="Class distribution in day 5", xlabel="Class", ylabel="Number of images")
plt.show()

The majority of `good`-class examples come from the day-5 stage. Taking this into account, e.g. by conditioning the classifier on the day an image was taken, should improve the classifier's accuracy.

## Visualize examples of both classes at the day-3 and day-5 stages

In [None]:
# Draw 10 random day-3 examples per class. The fixed random_state makes Restart & Run All
# show the same examples every time.
rows_D3_not_good = train_D3_df.loc[train_D3_df["Class"] == 0].sample(10, random_state=42)
rows_D3_good = train_D3_df.loc[train_D3_df["Class"] == 1].sample(10, random_state=42)

images_D3_not_good = [
    Image.open(os.path.join(ROOT_DIR, "train", p)) for p in rows_D3_not_good["Image"].values
]
images_D3_good = [
    Image.open(os.path.join(ROOT_DIR, "train", p)) for p in rows_D3_good["Image"].values
]

### Day-3

In [None]:
# Top row: "not good" day-3 examples; bottom row: "good" ones (images fill the 2x10 grid row-major).
# A 20x5 figure fits a 2x10 grid without the blank margins a square figure leaves. The extra
# vertical pad leaves room for the bottom row's titles.
fig = plt.figure(figsize=(20.0, 5.0))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(2, 10),
    axes_pad=(0.2, 0.4),
)

labelled_D3 = [("not good", im) for im in images_D3_not_good] + [
    ("good", im) for im in images_D3_good
]
for ax, (label, im) in zip(grid, labelled_D3):
    # Converting to RGB means imshow never applies its default colormap to single-channel images.
    ax.imshow(im.convert("RGB").resize((224, 224)))
    ax.axis("off")
    ax.set_title(label)

plt.show()

### Day-5

In [None]:
# Draw 10 random day-5 examples per class. The fixed random_state makes Restart & Run All
# show the same examples every time.
rows_D5_not_good = train_D5_df.loc[train_D5_df["Class"] == 0].sample(10, random_state=42)
rows_D5_good = train_D5_df.loc[train_D5_df["Class"] == 1].sample(10, random_state=42)

images_D5_not_good = [
    Image.open(os.path.join(ROOT_DIR, "train", p)) for p in rows_D5_not_good["Image"].values
]
images_D5_good = [
    Image.open(os.path.join(ROOT_DIR, "train", p)) for p in rows_D5_good["Image"].values
]

In [None]:
# Top row: "not good" day-5 examples; bottom row: "good" ones (images fill the 2x10 grid row-major).
# A 20x5 figure fits a 2x10 grid without the blank margins a square figure leaves. The extra
# vertical pad leaves room for the bottom row's titles.
fig = plt.figure(figsize=(20.0, 5.0))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(2, 10),
    axes_pad=(0.2, 0.4),
)

labelled_D5 = [("not good", im) for im in images_D5_not_good] + [
    ("good", im) for im in images_D5_good
]
for ax, (label, im) in zip(grid, labelled_D5):
    # Converting to RGB means imshow never applies its default colormap to single-channel images.
    ax.imshow(im.convert("RGB").resize((224, 224)))
    ax.axis("off")
    ax.set_title(label)

plt.show()

## Visualize augmentations

In [None]:
# 40 random training images (seeded for reproducibility) to preview the augmentations on.
# `.convert("RGB")` guarantees 3-channel input. The 3-channel Normalize in `train_transforms` would
# raise on a 1-channel (e.g. "L") or 4-channel ("RGBA") tensor. NOTE(review): the source mode of
# these files is not visible here; the conversion is a no-op if they are already RGB.
image_batch = [
    Image.open(os.path.join(ROOT_DIR, "train", p)).convert("RGB")
    for p in train_df.sample(40, random_state=42)["Image"].values
]

In [None]:
# Training-time augmentation pipeline.
# - Geometric and colour transforms run on the PIL image.
# - Normalize uses the ImageNet per-channel mean/std.
# - RandomErasing comes after Normalize, matching the order in the torchvision RandomErasing
#   example. An erased patch is therefore filled with 0 in normalized space (the ImageNet mean
#   colour) rather than pure black.
train_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),
        transforms.RandomPerspective(),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.RandomErasing(),
    ]
)

In [None]:
# The torchvision random transforms draw from torch's global RNG. Seeding it makes the
# augmented preview reproducible under Restart & Run All.
torch.manual_seed(42)
transformed_batch = [train_transforms(im) for im in image_batch]

In [None]:
# Tile the augmented tensors 10 per row.
# - normalize=True min-max rescales the whole grid into [0, 1] so the mean/std-normalized values
#   are displayable.
# - permute(1, 2, 0) moves the channel axis last, as imshow expects.
fig, ax = plt.subplots(figsize=(10, 10))
transformed_grid = make_grid(transformed_batch, nrow=10, normalize=True).permute(1, 2, 0)
ax.imshow(transformed_grid)
ax.set_title("Random training augmentations")
ax.axis("off")
plt.show()