In [13]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_b4
from tqdm import tqdm

from datasets import WebvisionDataset, MislabelledDataset

In [5]:
# Notebook-wide configuration.
# Prefer the second GPU ('cuda:1'), but fall back gracefully so the notebook
# still runs on machines with a single GPU or no GPU at all instead of
# failing later with an invalid-device error when the model is moved.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print("Using device:", device)

# Root directory handed to the dataset constructor.
data_root = 'data/'
# Side length (in pixels) used by the image crop/resize transforms.
webvision_img_size = 227

In [3]:
# Build the pretrained EfficientNet-B4 classifier and place it on the
# configured device in one step.
model = efficientnet_b4(pretrained=True).to(device)

# Report the total number of scalar parameters across all tensors.
print("Number of parameters:", sum(param.numel() for param in model.parameters()))

Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth" to /home/kylew/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-7eb33cd5.pth


  0%|          | 0.00/74.5M [00:00<?, ?B/s]

Number of Parameters: 19341616


In [17]:
# prep dataset and dataloader

# Transform given to the base WebVision dataset: convert to a tensor and
# normalise each channel with the mean/std constants below (these match the
# channel statistics torchvision documents for its pretrained models).
im_web_normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Transform given to the MislabelledDataset wrapper. The crop size now comes
# from `webvision_img_size` instead of a duplicated literal 227.
# NOTE(review): cropping *before* resizing makes the Resize a no-op on the
# already-square crop and keeps only the central patch of larger images;
# Resize -> CenterCrop is the more common order. Left unchanged to preserve
# the existing results -- confirm intent.
web_test = transforms.Compose([
    transforms.CenterCrop(webvision_img_size),
    transforms.Resize(webvision_img_size),
])

train_data = WebvisionDataset(data_root, num_classes=1000, train=True, include_flickr=False, transform=im_web_normalize)
train_dataset = MislabelledDataset(train_data, num_classes=1000, cache=False, transform=web_test)

print("Dataset length:", len(train_dataset))

# Sequential (unshuffled) pass over the data; the duplicated assignment target
# (`train_loader = train_loader = ...`) has been removed.
train_loader = DataLoader(train_dataset, 32, shuffle=False, num_workers=8,
                          pin_memory=True, prefetch_factor=2)
print("Number of batches:", len(train_loader))


Dataset length: 980449
Number of batches: 30640


In [18]:
# create ntm
# NTM[i][j] = predicted i by model, labelled j by webvision
# The matrix is re-created on every run of this cell, so re-running is safe.
NTM = np.zeros((1000, 1000))
model.eval()
loop = tqdm(train_loader, desc="Evaluating Images", total=len(train_loader))

# Pure inference: disable autograd so no graph is built for each batch,
# which saves GPU memory and time over the full pass.
with torch.no_grad():
    for batch_x, batch_y, batch_real, batch_ind in loop:
        if device is not None:
            batch_x = batch_x.to(device)

        # Call the module rather than .forward() so registered hooks still run.
        out = model(batch_x)

        preds = out.argmax(dim=-1).cpu().numpy()
        labels = batch_y.detach().cpu().numpy()

        # Unbuffered accumulation: np.add.at counts every (pred, label) pair,
        # including pairs that occur more than once in the same batch, and
        # replaces the per-sample Python loop.
        np.add.at(NTM, (preds, labels), 1)

Evaluating Images: 100%|█████████████████████████| 30640/30640 [24:52<00:00, 20.53it/s]


In [20]:
# Persist the estimated matrix as plain text.
# Create the output directory first: np.savetxt does not create missing
# folders, and failing here would waste the long evaluation pass above.
os.makedirs("saved", exist_ok=True)
np.savetxt("saved/web_est_NTM.txt", NTM)

In [None]:
# Also store the matrix in NumPy's binary .npy format. The original call had
# no arguments and would raise a TypeError when executed.
# NOTE(review): the file name mirrors the text export above -- adjust if a
# different target was intended.
np.save("saved/web_est_NTM.npy", NTM)