In [1]:
import os
import pandas as pd
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset


class Cub2011(Dataset):
    """
    https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py
    """
    
    base_folder = 'CUB_200_2011/images'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, loader=default_loader, download=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = default_loader
        self.train = train

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

    def _load_metadata(self):
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        try:
            self._load_metadata()
        except Exception:
            return False

        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target, 0, 0

In [24]:
import torchvision
from torchvision.transforms import ToTensor, Resize, Compose
import torch 
from torchvision.models import Wide_ResNet50_2_Weights

trans = Wide_ResNet50_2_Weights.DEFAULT.transforms() #Compose([ToTensor(), Resize((128, 128))])

train_data = data = Cub2011("/home/ki/projects/work/papers/logic-anomaly/code/data/CUB_200_2011/", train=True, transform=trans)
test_data = Cub2011("/home/ki/projects/work/papers/logic-anomaly/code/data/CUB_200_2011/", train=False, transform=trans)

Files already downloaded and verified
Files already downloaded and verified


In [25]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, 
                                           shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, 
                                          shuffle=False, num_workers=2)

In [26]:
from pytorch_ood.model import WideResNet
from torchvision.models import wide_resnet50_2


# def override 
def LeNet(num_classes=None, *args, **kwargs):
    model = wide_resnet50_2(num_classes=1000,  weights=Wide_ResNet50_2_Weights.DEFAULT) # ,
    # model = WideResNet(*args, num_classes=1000, **kwargs, pretrained="imagenet32")
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [27]:
from torch import nn 
from torch.utils.data import Dataset, DataLoader
from os.path import join
import pandas as pd 
from PIL import Image
from torch.optim import SGD
import seaborn as sb 

device = "cuda"

label_net = LeNet(num_classes=200).to(device)

learning_rate = 1e-3
momentum = 0.9
criterion = nn.CrossEntropyLoss()
optimizer = SGD(label_net.parameters(), lr=learning_rate, momentum=momentum, nesterov=True)

accs = []
running_loss = 0.0

for epoch in range(100):
    
    for i, batch in enumerate(train_loader):
        inputs, labels, colors, shapes  = batch
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = label_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 10:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            inputs, labels, colors, shapes  = batch
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = label_net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += shapes.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the label network on the test images: {correct / total:.2%}')
    accs.append(correct / total)

print('Finished Training LabelNet')

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth" to /home/ki/.cache/torch/hub/checkpoints/wide_resnet50_2-9ba9bcbe.pth


  0%|          | 0.00/263M [00:00<?, ?B/s]

[1,    11] loss: 0.029
[1,   111] loss: 0.263
Accuracy of the label network on the test images: 1.57%
[2,    11] loss: 0.226
[2,   111] loss: 0.246
Accuracy of the label network on the test images: 3.21%
[3,    11] loss: 0.205
[3,   111] loss: 0.208
Accuracy of the label network on the test images: 5.09%
[4,    11] loss: 0.160
[4,   111] loss: 0.149
Accuracy of the label network on the test images: 6.09%
[5,    11] loss: 0.108
[5,   111] loss: 0.100
Accuracy of the label network on the test images: 6.59%
[6,    11] loss: 0.077
[6,   111] loss: 0.070
Accuracy of the label network on the test images: 7.39%
[7,    11] loss: 0.054
[7,   111] loss: 0.051
Accuracy of the label network on the test images: 7.65%
[8,    11] loss: 0.042
[8,   111] loss: 0.039
Accuracy of the label network on the test images: 8.16%
[9,    11] loss: 0.032
[9,   111] loss: 0.031
Accuracy of the label network on the test images: 8.70%
[10,    11] loss: 0.026
[10,   111] loss: 0.024
Accuracy of the label network on t

KeyboardInterrupt: 