In [1]:
!nvidia-smi

Sat Apr 17 19:02:11 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 435.21       Driver Version: 435.21       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce RTX 208...  Off  | 00000000:04:00.0 Off |                  N/A |
| 31%   34C    P0    67W / 250W |      0MiB / 11019MiB |     17%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:05:00.0 Off |                  N/A |
| 30%   19C    P8     4W / 250W |      0MiB / 11019MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce RTX 208...  Off  | 00000000:08:00.0 Off |                  N/A |
| 30%   

In [2]:
train_fp = r'training/labeled'
unlabeled_fp = r'training/unlabeled'
vaild_fp = r'validation'
test_fp = r'training/unlabeled'

In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torchvision
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from torchvision.datasets import DatasetFolder
from tqdm.auto import tqdm

train_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.,0.,0.],[1.,1.,1.])

])


test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.,0.,0.],[1.,1.,1.])

])




train_set = DatasetFolder(train_fp, loader = lambda x: Image.open(x),
                          extensions = "jpg", transform = train_tfm)
valid_set = DatasetFolder(vaild_fp, loader = lambda x: Image.open(x),
                          extensions = "jpg", transform = test_tfm)
unlabeled_set = DatasetFolder(unlabeled_fp,
                              loader = lambda x: Image.open(x), extensions = "jpg", transform = train_tfm)
test_set = DatasetFolder(test_fp, loader = lambda x: Image.open(x),
                         extensions = "jpg", transform = test_tfm)

train_loader = DataLoader(train_set, batch_size = 68, shuffle = True, )
valid_loader = DataLoader(valid_set, batch_size = 250, shuffle = True, )
test_loader = DataLoader(test_set, batch_size = 250, shuffle = False)


class contact_set(torch.utils.data.Dataset):
    def __init__(self, old_set, new_set):
        self.old_set = old_set
        self.new_set = new_set

    def __len__(self):
        return len(self.old_set) + len(self.new_set)

    def __getitem__(self, idx):
        if idx < len(self.old_set):
            return self.old_set[idx]
        else:
            return self.new_set[idx - len(self.old_set)]
model = nn.Sequential(torchvision.models.resnet18(pretrained = True),
                     nn.Linear(1000, 200),
                     nn.SELU(),
                     nn.Linear(200, 11))

def get_pseudo_label(old_set, unlabeled_set, model, batch_size = 350, threshold = 0.7):
    data_loader = DataLoader(unlabeled_set, batch_size)
    model.eval()
    softmax = nn.Softmax(-1)
    for img_batch, _ in tqdm(data_loader):
        with torch.no_grad():
            logits = model(img_batch.cuda()).cpu()
        probs = softmax(logits)
        probs_max_bs_1, pos_max_bs_1 = probs.max(-1)
        bool_index = probs_max_bs_1 >= threshold
        new_data = img_batch[bool_index]
        new_label = pos_max_bs_1[bool_index]
        new_set = TensorDataset(new_data, new_label)
        old_set = contact_set(old_set, new_set)
    return old_set


device = "cuda"

model = model.cuda()
model.device = device

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4, weight_decay = 1e-5)
n_epochs = 200

for epoch in range(n_epochs):
    if epoch > 3:

        new_set = get_pseudo_label( train_set,unlabeled_set, model)
        print('New set len: ', len(new_set) - len(train_set))
        train_loader = DataLoader(new_set, batch_size = 300)

    model.train()

    train_loss = []
    train_accs = []

    for batch in tqdm(train_loader):
        imgs, labels = batch
        logits = model(imgs.to(device))
        loss = criterion(logits, labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm = 10)
        optimizer.step()
        acc = (logits.argmax(dim = -1) == labels.to(device)).float().mean()
        train_loss.append(loss.item())
        train_accs.append(acc)

    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    model.eval()

    valid_loss = []
    valid_accs = []

    for batch in tqdm(valid_loader):
        imgs, labels = batch
        with torch.no_grad():
            logits = model(imgs.to(device))
        loss = criterion(logits, labels.to(device))
        acc = (logits.argmax(dim = -1) == labels.to(device)).float().mean()
        valid_loss.append(loss.item())
        valid_accs.append(acc)

    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

model.eval()

predictions = []

for batch in tqdm(test_loader):
    imgs, labels = batch
    with torch.no_grad():
        logits = model(imgs.to(device))
    predictions.extend(logits.argmax(dim = -1).cpu().numpy().tolist())

with open("predict.csv", "w") as f:
    f.write("Id,Category\n")
    for i, pred in enumerate(predictions):
        f.write(f"{i},{pred}\n")

HBox(children=(FloatProgress(value=0.0, max=46.0), HTML(value='')))


[ Train | 001/200 ] loss = 0.85429, acc = 0.71503


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 001/200 ] loss = 0.93504, acc = 0.74542


HBox(children=(FloatProgress(value=0.0, max=46.0), HTML(value='')))


[ Train | 002/200 ] loss = 0.19388, acc = 0.93689


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 002/200 ] loss = 0.94597, acc = 0.76692


HBox(children=(FloatProgress(value=0.0, max=46.0), HTML(value='')))


[ Train | 003/200 ] loss = 0.06476, acc = 0.97673


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 003/200 ] loss = 0.74849, acc = 0.79675


HBox(children=(FloatProgress(value=0.0, max=46.0), HTML(value='')))


[ Train | 004/200 ] loss = 0.05226, acc = 0.98517


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 004/200 ] loss = 0.91136, acc = 0.80475


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5879


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


[ Train | 005/200 ] loss = 1.88337, acc = 0.57689


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 005/200 ] loss = 0.68041, acc = 0.81233


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5052


HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))


[ Train | 006/200 ] loss = 1.02336, acc = 0.69095


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 006/200 ] loss = 0.73670, acc = 0.79617


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5697


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


[ Train | 007/200 ] loss = 0.70242, acc = 0.78749


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 007/200 ] loss = 1.08778, acc = 0.69975


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5675


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


[ Train | 008/200 ] loss = 0.57637, acc = 0.82412


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 008/200 ] loss = 1.09636, acc = 0.72750


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5477


HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))


[ Train | 009/200 ] loss = 0.53563, acc = 0.83741


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 009/200 ] loss = 1.35517, acc = 0.60458


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5392


HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))


[ Train | 010/200 ] loss = 0.55417, acc = 0.82636


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 010/200 ] loss = 2.05659, acc = 0.50567


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5412


HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))


[ Train | 011/200 ] loss = 0.62426, acc = 0.81878


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 011/200 ] loss = 2.35789, acc = 0.43408


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5527


HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))


[ Train | 012/200 ] loss = 0.57908, acc = 0.81487


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 012/200 ] loss = 2.32973, acc = 0.46008


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5905


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


[ Train | 013/200 ] loss = 0.38336, acc = 0.88758


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 013/200 ] loss = 2.61254, acc = 0.43617


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  6110


HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))


[ Train | 014/200 ] loss = 0.31062, acc = 0.90485


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 014/200 ] loss = 2.85738, acc = 0.41408


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  6087


HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))


[ Train | 015/200 ] loss = 0.35394, acc = 0.89611


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 015/200 ] loss = 3.19301, acc = 0.37342


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5791


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


[ Train | 016/200 ] loss = 0.31031, acc = 0.90308


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 016/200 ] loss = 3.16126, acc = 0.41442


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  5918


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


[ Train | 017/200 ] loss = 0.27409, acc = 0.92043


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 017/200 ] loss = 3.72380, acc = 0.36617


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  6056


HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))


[ Train | 018/200 ] loss = 0.23131, acc = 0.93176


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 018/200 ] loss = 3.73521, acc = 0.37767


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  6343


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))


[ Train | 019/200 ] loss = 0.14382, acc = 0.95912


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 019/200 ] loss = 3.42397, acc = 0.39217


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  6456


HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))


[ Train | 020/200 ] loss = 0.09637, acc = 0.97408


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 020/200 ] loss = 3.99026, acc = 0.41133


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


New set len:  6577


HBox(children=(FloatProgress(value=0.0, max=33.0), HTML(value='')))


[ Train | 021/200 ] loss = 0.06946, acc = 0.98091


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


[ Valid | 021/200 ] loss = 3.87935, acc = 0.38500


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

KeyboardInterrupt: 