In [6]:
import os

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

%matplotlib inline

def create_model():
    """Return a pretrained Resnet model.
    """
    n_classes = 2
    model_ft = models.resnet34(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, n_classes)
    return model_ft

In [7]:
# resize image to 224x224 for resnet18

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# second, fetch raw data
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

In [8]:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=15, shuffle=True) for x in ['train', 'val']}
X_train, y_train = next(iter(dataloaders['val']))

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = create_model()
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




In [10]:
n_epochs = 15

for epoch in range(n_epochs):
    # step 1: train on labeled data
    optimizer_ft.zero_grad()
    output = model(X_train)
    real_loss = criterion(output, y_train)
    real_loss.backward()
    optimizer_ft.step()

     # validate
    model.eval()
    X_test, y_test = next(iter(dataloaders['train']))
    test_output = model(X_test)
    test_probs = F.softmax(test_output, dim=1)
    test_preds = test_probs.argmax(dim=1, keepdim=True)
    correct = test_preds.eq(y_test.view_as(test_preds)).sum().item()
    acc = correct / len(y_test)
    print(f"epoch:{epoch} accuracy:{acc}")

epoch:0 accuracy:0.4
epoch:1 accuracy:0.4666666666666667
epoch:2 accuracy:0.8
epoch:3 accuracy:0.7333333333333333
epoch:4 accuracy:0.5333333333333333
epoch:5 accuracy:0.6666666666666666
epoch:6 accuracy:0.7333333333333333
epoch:7 accuracy:0.8666666666666667
epoch:8 accuracy:0.8666666666666667
epoch:9 accuracy:1.0
epoch:10 accuracy:1.0
epoch:11 accuracy:0.9333333333333333
epoch:12 accuracy:0.7333333333333333
epoch:13 accuracy:0.8
epoch:14 accuracy:0.8666666666666667


In [12]:
n_epochs = 4
threshold = .7

for epoch in range(n_epochs):
    
    model.train()
    
    for X_unlabeled, y_unlabeled in dataloaders['train']:

        # step 2: generate pseudo labels from weakly augmented data
        X_weak = transforms.RandomErasing(
            p=1, 
            ratio=(1, 1), 
            scale=(0.01, 0.01), 
            value=.1)(X_unlabeled)
        output = model(X_weak)
        prob = F.softmax(output, dim=1)
                
        # find the indices of samples where prob exceeds confidence threshold
        idx = torch.logical_xor(prob[:,0] > threshold, prob[:,1] > threshold)
        
        # we want to learn from these samples because the model is highly confident
        X_keep = X_unlabeled[idx]
        
        # use argmax of probs as psuedo labels, we might be wrong though
        # y_keep = y_unlabeled[idx]
        idxx = np.where(idx)[0]
        y_keep = torch.argmax(prob, dim=1)[idxx]
        print(prob[idxx], y_keep)
        
        # step 4: generate strongly augmented data
        X_strong = transforms.RandomErasing(
            p=1, 
            ratio=(1, 1), 
            scale=(0.05, 0.05), 
            value=.1)(X_keep)

        # step 5: learn using stronglly augmented data and pseudo labels
        output = model(X_strong)
        fake_loss = criterion(output, y_keep)

        # step 6: backprop
        fake_loss.backward()
        optimizer_ft.step()
        
    # validate
    model.eval()
    X_test, y_test = next(iter(dataloaders['val']))
    test_output = model(X_test)
    test_probs = F.softmax(test_output, dim=1)
    test_preds = test_probs.argmax(dim=1, keepdim=True)
    correct = test_preds.eq(y_test.view_as(test_preds)).sum().item()
    acc = correct / len(y_test)
    print(f"epoch:{epoch} accuracy:{acc}")

tensor([[0.1881, 0.8119],
        [0.7413, 0.2587],
        [0.2314, 0.7686],
        [0.1920, 0.8080],
        [0.7015, 0.2985],
        [0.9310, 0.0690],
        [0.7401, 0.2599],
        [0.1858, 0.8142],
        [0.7373, 0.2627]], grad_fn=<IndexBackward>) tensor([1, 0, 1, 1, 0, 0, 0, 1, 0])
tensor([[0.7795, 0.2205],
        [0.8922, 0.1078],
        [0.2872, 0.7128],
        [0.7152, 0.2848],
        [0.8239, 0.1761],
        [0.7852, 0.2148],
        [0.1151, 0.8849],
        [0.1168, 0.8832],
        [0.8540, 0.1460]], grad_fn=<IndexBackward>) tensor([0, 0, 1, 0, 0, 0, 1, 1, 0])
tensor([[0.8599, 0.1401],
        [0.8379, 0.1621],
        [0.7443, 0.2557],
        [0.0550, 0.9450],
        [0.7767, 0.2233],
        [0.1944, 0.8056],
        [0.7447, 0.2553],
        [0.2516, 0.7484]], grad_fn=<IndexBackward>) tensor([0, 0, 0, 1, 0, 1, 0, 1])
tensor([[0.9371, 0.0629],
        [0.2792, 0.7208],
        [0.7136, 0.2864],
        [0.8176, 0.1824],
        [0.8346, 0.1654],
        [0.