In [18]:
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
def fix_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
fix_seed()

In [20]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.1307, 0.3081)
])

data_path = 'D:/Dropbox/Dropbox/Work/Study/dataset/'
mnist_ds = list(torchvision.datasets.MNIST(data_path,
                                           download=False,
                                           train=True,
                                           transform=transform))

labeled_size = int(len(mnist_ds) * 0.01)
print('labeled 1 % data:', labeled_size)

labeled_ds = mnist_ds[:labeled_size]
unlabeled_ds = mnist_ds[labeled_size:]
print(f'labeled_data: {len(labeled_ds)}, unlabeled_data: {len(unlabeled_ds)}')

mnist_dl = DataLoader(mnist_ds, batch_size=128, shuffle=False)

labeled 1 % data: 600
labeled_data: 600, unlabeled_data: 59400


In [21]:
def train_test_split(labeled_ds):
    '''
    split train_ds and validation_ds with random index of labeled_ds
    '''
    
    size = len(labeled_ds)
    validation_idx = sorted(list(np.random.choice(range(size), int(size * 0.2), replace=False)),
                            reverse=True)
    validation_ds = [labeled_ds[i] for i in validation_idx]
    for i in validation_idx:
        del labeled_ds[i]
    train_ds = labeled_ds
    
    print(f'train data size: {len(train_ds)}, validation data size: {len(validation_ds)}')
    
    train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
    validation_dl = DataLoader(validation_ds, batch_size=8, shuffle=False)
    
    return train_dl, validation_dl

In [22]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, 2, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, 0),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        self.fc = nn.Sequential(
            nn.Linear(64 * 5 * 5, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
        
    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [23]:
def train_val(train_dl, validation_dl, epochs=5):
    '''
    train model, return best validation accuracy
    '''
    
    best_accuracy = []
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        for imgs, labels in train_dl:
            imgs, labels = imgs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            preds = model(imgs)
            batch_loss = loss_fn(preds, labels)
            train_loss.append(batch_loss.item())
            
            batch_loss.backward()
            optimizer.step()
        train_loss = np.array(train_loss).mean()
        
        model.eval()
        with torch.no_grad():
            class_correct = 0
            val_loss = []
            for imgs, labels in validation_dl:
                imgs, labels = imgs.to(device), labels.to(device)
                
                preds = model(imgs)
                batch_loss = loss_fn(preds, labels)
                val_loss.append(batch_loss.item())
                
                y_preds = torch.max(torch.softmax(preds, dim=1), dim=1)[1]
                c = (labels==y_preds)
                class_correct += c.tolist().count(True)
            val_loss = np.array(val_loss).mean()
            accuracy = class_correct / (validation_dl.batch_size * len(validation_dl))
            best_accuracy.append(accuracy)
            
        print('epoch: {}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}'
              .format(epoch, train_loss, val_loss, accuracy))
            
    best_accuracy = max(best_accuracy)
    
    return best_accuracy

In [24]:
def test_unlabeled(unlabeled_ds, threshold=0.99):
    '''
    if predicted label' confidence is higher then threshold,
    labeling y as predicted y
    '''
    
    unlabeled = torch.stack([i[0] for i in unlabeled_ds])
    unlabeled_dl = DataLoader(unlabeled, batch_size=8)
    
    pseudo_labeled = []
    index = []
    model.eval()
    with torch.no_grad():
        for i, imgs in enumerate(unlabeled_dl):
            imgs = imgs.to(device)
            
            preds = model(imgs)
            preds = torch.softmax(preds, dim=1)
            conf_preds, y_preds = torch.max(preds, dim=1)
            for j, confidence in enumerate(conf_preds):
                if confidence >= threshold:
                    pseudo_labeled.append((imgs[j].cpu(), y_preds[j].item()))
                    idx = (i * unlabeled_dl.batch_size) + j
                    index.append(idx)
    
    print(f'pseudo_labeled size: {len(pseudo_labeled)}')
    
    return pseudo_labeled, index

In [25]:
def update_dataset(pseudo_labeled, index):
    '''
    add pseudo_labeled to labeled_dataset
    '''
    
    for i in sorted(index, reverse=True):
        del unlabeled_ds[i]
    
    labeled_ds.extend(pseudo_labeled)
    
    print(f'labeled size: {len(labeled_ds)}, unlabeled size: {len(unlabeled_ds)}')
    
    return labeled_ds, unlabeled_ds

In [26]:
def evaluation():
    model.eval()
    with torch.no_grad():
        class_correct = 0
        for imgs, labels in mnist_dl:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs)
            y_preds = torch.max(torch.softmax(preds, dim=1), dim=1)[1]
            c = (labels==y_preds)
            class_correct += c.tolist().count(True)
        accuracy = class_correct / (mnist_dl.batch_size * len(mnist_dl))
    print(f'evaluation accuracy: {accuracy:.4f}')
    return accuracy

In [27]:
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

In [28]:
train_dl, validation_dl = train_test_split(labeled_ds.copy())
train_val(train_dl, validation_dl)
evaluation()
pseudo_labeled, index = test_unlabeled(unlabeled_ds)
labeled_ds, unlabeled_ds = update_dataset(pseudo_labeled, index)

train data size: 480, validation data size: 120
epoch: 1, train_loss: 1.1118, val_loss: 0.5429, val_acc: 0.8250
epoch: 2, train_loss: 0.2300, val_loss: 0.4343, val_acc: 0.8667
epoch: 3, train_loss: 0.0754, val_loss: 0.5073, val_acc: 0.8333
epoch: 4, train_loss: 0.0434, val_loss: 0.4650, val_acc: 0.8750
epoch: 5, train_loss: 0.0073, val_loss: 0.4367, val_acc: 0.8833
evaluation accuracy: 0.8838
pseudo_labeled size: 34277
labeled size: 34877, unlabeled size: 25123


In [29]:
def find_wrong_pseudo_labeled(mnist_ds, labeled_ds):
    mnist_ds.sort(key=lambda x: x[0].sum())
    labeled_ds.sort(key=lambda x: x[0].sum())

    idx = 0
    wrong_label_cnt = 0
    for img, label in tqdm(labeled_ds[labeled_size:]):
        for n, (img2, label2) in enumerate(mnist_ds[idx:]):
            if (img==img2).view(-1).tolist().count(False) == 0:
                idx += n
                if label != label2:
                    wrong_label_cnt += 1
                break
    print('pseudo labeled size: {}, wrong label count: {}, {:.2f}%'
          .format(len(labeled_ds)-labeled_size, wrong_label_cnt, wrong_label_cnt/(len(labeled_ds)-labeled_size)*100))

In [30]:
find_wrong_pseudo_labeled(mnist_ds.copy(), labeled_ds.copy())

  0%|          | 0/34277 [00:00<?, ?it/s]

pseudo labeled size: 34277, wrong label count: 406, 1.18%
