In [1]:
# Step 1: import dependencies, pick the device, fix the seeds and build the CelebA data loaders
%matplotlib inline
import torchvision.datasets
import math
import torchvision.transforms as tvt
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wget
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as tfms
from torch.utils.data import DataLoader, Subset, Dataset, random_split
from torchvision.utils import make_grid
from PIL import Image
from time import time
from tqdm import tqdm
import random
from transformers import ViTModel

# Prefer the second GPU (index 1), as this notebook was originally configured.
# Fall back to the first GPU, and finally to the CPU, so that the later
# `.to(device)` calls do not fail on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

def seed_everything(seed):
    """
    Seed every random number generator this notebook relies on.

    Applies `seed` to Python's `random` module, NumPy's global generator and
    PyTorch, then sets cuDNN's `deterministic` flag to True and its
    `benchmark` flag to False.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(1024)

image_size = 64
batch_size = 256

# The train and test splits share one preprocessing recipe: resize to a square
# of `image_size` pixels, convert to a tensor, then normalise every channel
# with mean 0.5 and std 0.5. Each split still receives its own Compose instance.
dataset, test_dataset = (
    torchvision.datasets.CelebA(
        "../../../../celeba/datasets/",
        split=split_name,
        transform=tvt.Compose([
            tvt.Resize((image_size, image_size)),
            tvt.ToTensor(),
            tvt.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]),
    )
    for split_name in ('train', 'test')
)

# Disabled 90/10 train/validation split of `dataset`, kept for reference:
# lengths = [int(len(dataset)*0.9), int(len(dataset)*0.1)]
# if sum(lengths) != len(dataset):
#     lengths[0] += len(dataset) - sum(lengths)
# train_dataset, val_dataset = random_split(dataset, lengths)

# Training batches are shuffled and the incomplete last batch is dropped;
# the test loader keeps the original order and every sample.
training_data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)
print('Done')

  from .autonotebook import tqdm as notebook_tqdm


Done


In [2]:
class VisionTransformer(nn.Module):
    """
    Classification head on top of a ViT backbone.

    The backbone is called on the input and must return an object exposing
    `last_hidden_state`. The vector at sequence position 0 (the [CLS] token
    for a Hugging Face `ViTModel`) is passed through a two-layer MLP that
    produces one logit per class.

    Args:
        vit: backbone module whose output has a `last_hidden_state` attribute.
        num_classes: number of output logits. Defaults to 2, the previously
            hard-coded value.
        hidden_size: width of the backbone's hidden states. When None it is
            read from `vit.config.hidden_size`, falling back to 768 (the
            previously hard-coded width) if the backbone has no such attribute.
    """

    def __init__(self, vit, num_classes=2, hidden_size=None):
        super().__init__()
        self.vit = vit
        if hidden_size is None:
            # Follow the backbone's width instead of assuming 768, so a
            # non-default ViTConfig does not break the head.
            hidden_size = getattr(getattr(vit, 'config', None), 'hidden_size', 768)
        # Attribute name and layer order are kept so existing state dicts load.
        self.seq = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch `x`."""
        hidden_states = self.vit(x).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.seq(first_token)

In [3]:
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score
from sklearn.metrics import confusion_matrix
from transformers import ViTConfig, ViTModel

def seed_everything(seed):
    """
    Make runs repeatable by fixing all random number generators to `seed`.

    Seeds PyTorch, NumPy and Python's `random`, and sets cuDNN to
    `deterministic=True` / `benchmark=False`.
    """
    # This re-declares the function already defined in the first cell,
    # with the same behaviour.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = False, True

def _evaluate_accuracy(model, data_loader, target_attr_index):
    """Return the accuracy of `model` on `data_loader` for one attribute column."""
    predictions, ground_truth = [], []
    was_training = model.training
    model.eval()  # switch off train-only layers (e.g. dropout) while scoring
    with torch.no_grad():
        for inputs, attributes in data_loader:
            ground_truth.extend(attributes[:, target_attr_index].cpu().numpy())
            logits = model(inputs.to(device))
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
    model.train(was_training)  # restore whatever mode the caller had
    return accuracy_score(ground_truth, predictions)


def train_model(num_epochs=15, lr=1e-4, target_attr_index=9, checkpoint_path='mode.pth'):
    """
    Train a small ViT classifier on one binary CelebA attribute.

    Builds an 8-layer, 8-head `ViTModel` for 64x64 images with 16x16 patches,
    wraps it in `VisionTransformer`, and optimises it with AdamW and
    cross-entropy on the module-level `training_data_loader`. After every
    epoch the accuracy on the module-level `test_data_loader` is printed, and
    the weights are saved whenever that accuracy improves.

    Args:
        num_epochs: number of passes over the training loader.
        lr: AdamW learning rate.
        target_attr_index: column of the attribute tensor used as the label
            (presumably "Blond_Hair" in torchvision's CelebA attribute order;
            TODO confirm).
        checkpoint_path: file the best `state_dict` is written to.

    Returns:
        None.
    """
    configuration = ViTConfig(num_hidden_layers=8, num_attention_heads=8,
                              intermediate_size=768, image_size=64, patch_size=16)
    model = VisionTransformer(ViTModel(configuration)).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    best_acc = 0

    for epoch_idx in range(num_epochs):
        model.train()
        # The description is set once per epoch and the postfix uses
        # refresh=False, so tqdm redraws at its own rate rather than twice per
        # batch; the forced redraws were flooding the notebook's output channel.
        with tqdm(training_data_loader, unit="batch", desc=f"epoch {epoch_idx}") as tepoch:
            for train_input, attributes in tepoch:
                train_input = train_input.to(device)
                # Class indices give the same loss as one-hot float targets
                # and also work on PyTorch versions without soft-label support.
                train_target = attributes[:, target_attr_index].long().to(device)

                optimizer.zero_grad()
                outputs = model(train_input)
                loss = criterion(outputs, train_target)
                loss.backward()
                optimizer.step()

                tepoch.set_postfix(ut_loss=loss.item(), refresh=False)

        # NOTE(review): model selection uses the test split, so the best
        # accuracy is optimistically biased; a held-out validation split
        # would be cleaner.
        test_acc = _evaluate_accuracy(model, test_data_loader, target_attr_index)
        print('acc', test_acc)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)
        
# Fix all RNG seeds to 1024, then run the full training and evaluation loop.
seed_everything(1024)    
train_model()

epoch 0.000000 : 100%|██████████████████████████████████████████████| 635/635 [04:22<00:00,  2.42batch/s, ut_loss=0.127]


acc 0.9388838793708045


epoch 1.000000 : 100%|██████████████████████████████████████████████| 635/635 [04:20<00:00,  2.44batch/s, ut_loss=0.113]


acc 0.9451958721570984


epoch 2.000000 : 100%|██████████████████████████████████████████████| 635/635 [04:25<00:00,  2.39batch/s, ut_loss=0.153]


acc 0.9424406372107004


epoch 3.000000 : 100%|██████████████████████████████████████████████| 635/635 [04:22<00:00,  2.42batch/s, ut_loss=0.177]


acc 0.9481514878268711


epoch 4.000000 : 100%|██████████████████████████████████████████████| 635/635 [04:21<00:00,  2.43batch/s, ut_loss=0.122]


acc 0.9472497745716862


epoch 5.000000 :  18%|████████▍                                     | 117/635 [00:49<03:26,  2.51batch/s, ut_loss=0.126]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

epoch 7.000000 : 100%|██████████████████████████████████████████████| 635/635 [03:57<00:00,  2.67batch/s, ut_loss=0.139]


acc 0.9472497745716862


epoch 8.000000 : 100%|██████████████████████████████████████████████| 635/635 [04:24<00:00,  2.40batch/s, ut_loss=0.112]


acc 0.946899108305781


epoch 9.000000 : 100%|█████████████████████████████████████████████| 635/635 [04:22<00:00,  2.42batch/s, ut_loss=0.0756]


acc 0.9459973950505961


epoch 10.000000 : 100%|█████████████████████████████████████████████| 635/635 [04:19<00:00,  2.45batch/s, ut_loss=0.108]


acc 0.9455465384230037


epoch 11.000000 : 100%|█████████████████████████████████████████████| 635/635 [04:26<00:00,  2.38batch/s, ut_loss=0.106]


acc 0.9447450155295061


epoch 12.000000 : 100%|████████████████████████████████████████████| 635/635 [04:26<00:00,  2.39batch/s, ut_loss=0.0666]


acc 0.9423905420298567


epoch 13.000000 : 100%|████████████████████████████████████████████| 635/635 [04:13<00:00,  2.51batch/s, ut_loss=0.0741]


acc 0.943492636008416


epoch 14.000000 : 100%|████████████████████████████████████████████| 635/635 [04:28<00:00,  2.37batch/s, ut_loss=0.0588]


acc 0.9419897805831079
