In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from PIL import Image

# This is for the progress bar.
from tqdm.auto import tqdm

In [6]:
# ---------- Experiment configuration ----------
# Training hyper-parameters.
batch_size = 4
n_epochs = 20  # set to 0 to skip training; the evaluation cell then scores the checkpoint at model_path
patience = 5  # If no improvement in 'patience' epochs, early stop

# Fixed seed so weight initialisation and data shuffling repeat across runs.
seed = 0
torch.manual_seed(seed)

# Available experiments: name -> (degradation type, level). Each entry maps to
# the data layout "<type>/<level>/{train,test}/" and the checkpoint
# "./model_<name>.ckpt". What <level> means (e.g. blur kernel size or pixel
# block size) is not visible here -- confirm against the data-generation step.
EXPERIMENTS = {
    "b15": ("blur", 15),
    "b45": ("blur", 45),
    "b99": ("blur", 99),
    "p4": ("pixel", 4),
    "p8": ("pixel", 8),
    "p16": ("pixel", 16),
}
experiment = "p4"  # pick one key of EXPERIMENTS

_kind, _level = EXPERIMENTS[experiment]
train_set_dir = f"{_kind}/{_level}/train/"
test_set_dir = f"{_kind}/{_level}/test/"
model_path = f"./model_{experiment}.ckpt"

In [7]:
# Preprocessing shared by the train and test datasets, applied in order:
#   1. ToTensor  -- PIL image -> float tensor [C, H, W] (8-bit values scaled to [0, 1])
#   2. Resize    -- to 128 x 128
#   3. Normalize -- per channel (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(128, 128)),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [8]:
class mydataset(Dataset):
    """Dataset over the ``.png`` images of a single directory.

    Each item is ``(transformed_image, label)``. The label is parsed from the
    file name ``s<label>_<n>.png`` (e.g. ``s10_1.png`` -> 10); a file name
    that does not yield an integer gets label -1.

    Args:
        path: directory scanned for ``.png`` files (sorted by path).
        tfm: transform applied to each RGB PIL image (defaults to the
            notebook-level ``transform``).
        files: optional explicit list of file paths; when non-empty it is
            used instead of scanning ``path``.
    """

    def __init__(self, path, tfm=transform, files=None):
        super().__init__()
        self.path = path
        if files:
            self.files = files
        else:
            self.files = sorted(
                os.path.join(path, x) for x in os.listdir(path) if x.endswith(".png")
            )
        if self.files:
            print(f"One {path} sample", self.files[0])
        else:
            # Report an empty directory instead of failing with IndexError.
            print(f"No .png files found for {path}")
        self.transform = tfm

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _parse_label(fname):
        """Parse the integer between the last 's' and the next '_' of the file name, else -1."""
        # Only the base name is inspected so directory names cannot interfere.
        name = os.path.basename(fname)
        try:
            return int(name.split("s")[-1].split("_")[0])
        except ValueError:
            return -1

    def __getitem__(self, idx):
        fname = self.files[idx]
        # The context manager closes the file handle. convert("RGB") guarantees
        # the 3 channels expected by Normalize and Net's first Conv2d (it is a
        # no-op for images that are already RGB).
        with Image.open(fname) as img:
            im = self.transform(img.convert("RGB"))
        return im, self._parse_label(fname)

In [9]:
train_set = mydataset(train_set_dir, tfm=transform)
test_set = mydataset(test_set_dir, tfm=transform)

# Settings shared by both loaders. Only the training data is shuffled:
# evaluation does not need random order, and a fixed order keeps test passes
# reproducible.
_loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **_loader_kwargs)

One pixel/4/train/ sample pixel/4/train/s10_1.png
One pixel/4/test/ sample pixel/4/test/s10_2.png


In [10]:
class Net(nn.Module):
    """VGG-style CNN classifier for [3, 128, 128] inputs.

    Five Conv(3x3)-BatchNorm-ReLU-MaxPool(2) stages reduce a [3, 128, 128]
    image to [512, 4, 4]. A three-layer MLP maps the flattened features to
    ``num_classes`` logits. The default of 41 matches the labels 1..40 scored
    by the evaluation cell; output 0 appears unused -- confirm against the data.

    Layer indices in ``self.cnn`` / ``self.fc`` match the original flat
    ``nn.Sequential`` layout, so existing ``state_dict`` checkpoints still load.
    """

    # (in_channels, out_channels) per conv stage; each stage halves H and W:
    # 128 -> 64 -> 32 -> 16 -> 8 -> 4.
    _STAGES = ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))

    def __init__(self, num_classes=41):
        super().__init__()
        layers = []
        for in_ch, out_ch in self._STAGES:
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, 1, 1),  # same-size 3x3 conv
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2, 2, 0),  # halve spatial size
            ]
        self.cnn = nn.Sequential(*layers)  # [3, 128, 128] -> [512, 4, 4]
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        out = self.cnn(x)
        out = torch.flatten(out, 1)  # [B, 512, 4, 4] -> [B, 8192]
        return self.fc(out)


net = Net()

In [12]:
# ---------- Setup ----------
# CPU is used by default. Set FORCE_CPU = False to train on a GPU when one is
# available. An explicit flag is used instead of monkey-patching
# torch.cuda.is_available, which would change torch's behaviour for the whole process.
FORCE_CPU = True
device = torch.device("cuda" if not FORCE_CPU and torch.cuda.is_available() else "cpu")

# Initialize a model, and put it on the device specified.
model = Net().to(device)

# CrossEntropyLoss takes raw logits (it applies log-softmax internally).
criterion = nn.CrossEntropyLoss()

# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (forward, backward, gradient clipping, step) when ``optimizer`` is
    given; otherwise evaluates with gradients disabled. Both metrics are
    weighted by batch size, so a smaller final batch does not skew them.
    Returns ``(nan, nan)`` for an empty loader.
    """
    training = optimizer is not None
    model.train(training)  # False == eval mode (BatchNorm uses running stats)
    total_loss, n_correct, n_seen = 0.0, 0, 0
    for imgs, labels in tqdm(loader):
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.set_grad_enabled(training):
            logits = model(imgs)
            loss = criterion(logits, labels)
        if training:
            optimizer.zero_grad()
            loss.backward()
            # Clip the gradient norms for stable training.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()
        batch = labels.size(0)
        total_loss += loss.item() * batch  # criterion returns the batch mean
        n_correct += (logits.argmax(dim=-1) == labels).sum().item()
        n_seen += batch
    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, n_correct / n_seen


# Early-stopping / checkpoint trackers.
stale = 0  # consecutive epochs without a test-accuracy improvement
best_acc = 0

# NOTE(review): the test split drives both checkpoint selection and early
# stopping, so best_acc is an optimistic estimate; a separate validation
# split would keep the test score unbiased.
for epoch in range(n_epochs):
    # ---------- Training ----------
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
    print(
        f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}"
    )

    # ---------- Validation ----------
    if len(test_loader) == 0:
        # Nothing to validate on: keep the most recent weights instead.
        torch.save(model.state_dict(), model_path)
        print("saving model at last epoch")
        continue

    test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)
    print(
        f"[ Test | {epoch + 1:03d}/{n_epochs:03d} ] loss = {test_loss:.5f}, acc = {test_acc:.5f}"
    )

    # save models
    if test_acc > best_acc:
        stale = 0
        best_acc = test_acc
        torch.save(model.state_dict(), model_path)
        print("saving model with acc {:.3f}".format(best_acc))
    else:
        stale += 1
        # Stop once `patience` consecutive epochs have passed without improvement.
        if stale >= patience:
            print(f"No improvment {patience} consecutive epochs, early stopping")
            break

100%|██████████| 96/96 [01:01<00:00,  1.56it/s]


[ Train | 001/020 ] loss = 3.53252, acc = 0.08594


100%|██████████| 33/33 [00:05<00:00,  5.57it/s]


[ Test | 001/020 ] loss = 2.81106, acc = 0.19697
saving model with acc 0.197


100%|██████████| 96/96 [01:06<00:00,  1.45it/s]


[ Train | 002/020 ] loss = 1.96475, acc = 0.42188


100%|██████████| 33/33 [00:05<00:00,  5.80it/s]


[ Test | 002/020 ] loss = 1.23108, acc = 0.62879
saving model with acc 0.629


100%|██████████| 96/96 [01:02<00:00,  1.53it/s]


[ Train | 003/020 ] loss = 0.96824, acc = 0.70833


100%|██████████| 33/33 [00:06<00:00,  5.25it/s]


[ Test | 003/020 ] loss = 0.70484, acc = 0.77273
saving model with acc 0.773


  9%|▉         | 9/96 [00:06<01:00,  1.44it/s]

In [81]:
import json

from sklearn.metrics import classification_report

# Score the best checkpoint written by the training cell on the test set.
model_best = Net().to(device)
# map_location lets a checkpoint saved on a different device load onto `device`.
model_best.load_state_dict(torch.load(model_path, map_location=device))
model_best.eval()

label_pred = []
label_true = []
with torch.no_grad():
    for data, labels in test_loader:
        logits = model_best(data.to(device))
        # Extend with the 1-D batch tensors directly: `.squeeze().tolist()` turns
        # a final batch of size 1 into a bare int, which `list += int` rejects.
        label_pred.extend(logits.argmax(dim=1).cpu().tolist())
        label_true.extend(labels.tolist())

# Classes scored in the report: labels 1..40 (file names s1_* .. s40_*).
class_labels = list(range(1, 41))
report_kwargs = dict(labels=class_labels, zero_division=0)
report = classification_report(label_true, label_pred, **report_kwargs)
print(report)

# Save a real JSON report beside the checkpoint, e.g. ./model_p4.ckpt -> ./model_p4_report.json.
report_path = os.path.splitext(model_path)[0] + "_report.json"
with open(report_path, "w") as f:
    json.dump(
        classification_report(label_true, label_pred, output_dict=True, **report_kwargs),
        f,
        indent=2,
        default=float,  # guard against numpy scalar values in the dict
    )
