In [383]:
from argparse import ArgumentParser
import numpy as np
import os
import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Callback, seed_everything
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset

import json

from pytorch_lightning.loggers import WandbLogger

from transformers import ViTForImageClassification, AdamW
import torch.nn as nn

from PIL import Image
from transformers import ViTFeatureExtractor

from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor,
                                    Lambda
                                   )


In [384]:
# reaction time psychophysical loss
def RtPsychCrossEntropyLoss(outputs, targets, psych):
#     print('in psych loss')
#     print(type(targets))
#     print(type(outputs))
#     print('the outputs are', outputs)

    targets = targets.cpu().detach().numpy()

    num_examples = targets.shape[0]
    batch_size = outputs.shape[0]
    
#     print('in loss', targets)
#     new_fucks = []
#     for elem in targets: 
# #         print('the elem is .... ', elem)
#         elem = label2id[elem]
# #         print('the elem is now .... ', elem)
#         new_fucks.append(int(elem))
    
#     targets = np.asarray([id2label[i] for i in targets])
#     targets = torch.as_tensor(targets)
#     print('here', new_fucks)
#     targets = np.asarray(new_fucks)
    
    # converting reaction time to penalty
    # 10002 is close to the max penalty time seen in the data
    for idx in range(len(psych)):   
        psych[idx] = abs(28 - psych[idx]) 
        # seems to be in terms of 10 for now,
        # will fix later

    # adding penalty to each of the output logits 
    for i in range(len(outputs)):
#         print('psych[i]', psych[i])
        val = psych[i] / 30
            
        outputs[i] += val 

    outputs = _log_softmax(outputs)
    outputs = outputs[range(batch_size), targets]

    return - torch.sum(outputs) / num_examples

# mean accuracy psychophysical loss
def AccPsychCrossEntropyLoss(outputs, targets, psych):
    num_examples = targets.shape[0]
    batch_size = outputs.shape[0]

    # converting accuracy to penalty
    for idx in range(len(psych)):   
        psych[idx] = abs(1 - psych[idx])

    for i in range(len(outputs)):
        outputs[i] += (psych[i])

    outputs = _log_softmax(outputs)
    outputs = outputs[range(batch_size), targets]

    return - torch.sum(outputs)/num_examples

def _softmax(x):
    exp_x = torch.exp(x)
    sum_x = torch.sum(exp_x, dim=1, keepdim=True)

    return exp_x/sum_x

def _log_softmax(x):
    return torch.log(_softmax(x))


In [385]:
def collate_fn(batch):
    pixel_values = torch.stack([x["pixel_values"] for x in batch])
    labels = torch.tensor([x["label"] for x in batch])
    rt = torch.tensor([x["rt"] for x in batch])

    return {"pixel_values": pixel_values, "label": labels, "rt": rt}

In [386]:
!pip3 install torchmetrics

[33mYou are using pip version 19.0.3, however version 22.1.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [387]:
import torchmetrics

class ViTLightningModule(pl.LightningModule):
    def __init__(self, vit, traindataset,valdataset,testdataset): # TODO:
        super(ViTLightningModule, self).__init__()
        self.vit = vit
        
        # self.vit = torchvision.models.vit_b_32(pretrained=True) 
#         self.num_labels=336
        self.criterion = nn.CrossEntropyLoss()
#         self.fc = nn.Linear(4096, 336)
        self.classifier = nn.Linear(self.vit.config.hidden_size, 164)

        self.accuracy = torchmetrics.Accuracy()
        
        self.train_dataset = traindataset
        self.val_dataset = valdataset
        self.test_dataset = testdataset

    def forward(self, pixel_values):
#         outputs = self.fc(outputs)
        outputs = self.vit(pixel_values=pixel_values, return_dict=False)
        print('type of outputs ', outputs[0].shape)
#         logits = self.classifier(outputs[0])

        return outputs[0]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=8, collate_fn=collate_fn)

    def train_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=8, collate_fn=collate_fn)

    def test_dataloader(self):
        # but we don't have this dataset here right now, so we have to add it, as a class call
        return DataLoader(self.test_dataset, batch_size=8, collate_fn=collate_fn)

    
    def common_step(self, batch, batch_idx):        
        pixel_values = batch['pixel_values']
        labels = batch['label']
        labels = labels.cpu().detach().numpy()

#         temp = []
#         for elem in labels: 
#             elem = label2id[elem]
#             temp.append(int(elem))
    
#         labels = np.asarray(temp)
#         labels = torch.from_numpy(labels).to(self.device)

        rts = batch['rt']

#         print("INFO. feats: {} - labels {} - rts {} --".format(pixel_values, labels, rts))
#         print("INFO. feats: {} - labels {} - rts {} --".format(type(pixel_values), type(labels), type(rts)))
    
    
        print('pixel_values as inputs', type(pixel_values))
        logits = self(pixel_values)
        
#         print("INFO. logits: {} - labels {} - rts {} --".format(logits, labels, rts))
#         print("INFO. ---shapes--- logits: {} - labels {} - rts {} --".format(logits.shape, labels.shape, rts.shape))

#         loss = RtPsychCrossEntropyLoss(logits, labels, rts)
        loss = self.criterion(logits, labels)
        
        labels_hat = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(labels_hat, labels)

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        # print('batch is in testing', batch)
        # 1/0
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)
        
        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=5e-5)
    

In [388]:
class msd_net_dataset(Dataset):
    def __init__(self,
                 json_path,
                 transform):

        with open(json_path) as f:
            data = json.load(f)
        #print("Json file loaded: %s" % json_path)

        self.data = data
        self.transform = transform
        self.random_weight = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[str(idx)]

        # Open the image and do normalization and augmentation
        img = Image.open(item["img_path"])
        img = img.convert('RGB')
        # needed this transform call
        img = self.transform(img)
        
        # Deal with reaction times
        if item["RT"] != None:
            rt = item["RT"]
        else:
            rt = 0

        return {
            "pixel_values": img,
            "label": item["label"],
            "rt": rt,
            "category": item["category"]
        }


In [389]:
def get_vit_features(img):
#     print('img is', img)
    feats = feature_extractor(img, return_tensors='pt')
#     inputs['labels'] = example['labels']
#     print('inputs is', inputs)
#     print('inputs length is', len(inputs))
#     print('inputs pixels is', inputs['pixel_values'])
#     print('inputs pixels length is', inputs['pixel_values'][0].shape)
    return feats['pixel_values'][0]

In [390]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            Lambda(lambda img: get_vit_features(img)),
#             ToTensor(),
#             normalize,
        ]
    )

#TODO: just split up the dataset fairly 
# maybe just train/test only
json_data_base = '/afs/crc.nd.edu/user/j/jdulay'
train_known_known_with_rt_path = os.path.join(json_data_base, "train_known_known_with_rt.json")
valid_known_known_with_rt_path = os.path.join(json_data_base, "valid_known_known_with_rt.json")

traindataset = msd_net_dataset(json_path=train_known_known_with_rt_path,
                                        transform=train_transforms)


valdataset = traindataset
testdataset = msd_net_dataset(json_path=valid_known_known_with_rt_path,
                                        transform=train_transforms)

testdataset

<__main__.msd_net_dataset at 0x148864672978>

In [391]:
labels = []
#TODO might need to do this for train idk
for i in range(len(testdataset)):
    item = testdataset.data[str(i)]['label']
    if item not in labels:
        labels.append(item)
len(labels)
# so we use this as the classes for now

164

In [392]:
id2label={str(i): c for i, c in enumerate(labels)}
label2id={c: str(i) for i, c in enumerate(labels)}

# id2labe
# label2id?
print(id2label)
print(len(labels))

{'0': 267, '1': 280, '2': 292, '3': 270, '4': 273, '5': 260, '6': 288, '7': 254, '8': 282, '9': 253, '10': 269, '11': 275, '12': 262, '13': 255, '14': 266, '15': 286, '16': 272, '17': 258, '18': 291, '19': 271, '20': 276, '21': 274, '22': 261, '23': 256, '24': 289, '25': 283, '26': 263, '27': 285, '28': 257, '29': 281, '30': 290, '31': 268, '32': 259, '33': 265, '34': 279, '35': 264, '36': 278, '37': 284, '38': 277, '39': 287, '40': 0, '41': 1, '42': 2, '43': 3, '44': 4, '45': 5, '46': 6, '47': 7, '48': 8, '49': 9, '50': 10, '51': 11, '52': 12, '53': 13, '54': 14, '55': 15, '56': 16, '57': 17, '58': 18, '59': 19, '60': 20, '61': 21, '62': 22, '63': 23, '64': 24, '65': 25, '66': 26, '67': 27, '68': 28, '69': 29, '70': 30, '71': 31, '72': 32, '73': 33, '74': 34, '75': 35, '76': 36, '77': 37, '78': 38, '79': 39, '80': 40, '81': 41, '82': 42, '83': 43, '84': 44, '85': 45, '86': 46, '87': 47, '88': 48, '89': 49, '90': 50, '91': 51, '92': 52, '93': 53, '94': 54, '95': 55, '96': 56, '97': 57,

In [393]:
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
feature_extractor

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

In [394]:
import os
import wandb

path = '/afa/crc.nd.edu/user/j/jdulay/.cache/'
if os.path.isdir(path):
    os.rmdir(path)
    
# wandb_logger = None
logger_name = '01_vit_model_search_DEBUG'
wandb_logger = WandbLogger(name=logger_name, project="general_model_search_DEBUG")
vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224-in21k', 
            id2label=id2label,
            label2id=label2id, 
            num_labels=len(id2label), # change here
#             num_labels=335,
        )

model = ViTLightningModule(vit=vit,traindataset=traindataset, valdataset=valdataset, testdataset=testdataset)

print('data len', len(testdataset))

trainer = pl.Trainer(
    max_epochs=20, 
    devices=1, 
#     accelerator='gpu',
    gpus=[3],
#     strategy='ddp',
#     auto_select_gpus=True, 
    logger=wandb_logger,
#     callbacks=[metrics_callback],
    progress_bar_refresh_rate=1000,
    limit_train_batches=0,
    limit_val_batches=0
) # hacks 

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True


data len 9257


In [395]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name       | Type                      | Params
---------------------------------------------------------
0 | vit        | ViTForImageClassification | 85.9 M
1 | criterion  | CrossEntropyLoss          | 0     
2 | classifier | Linear                    | 126 K 
3 | accuracy   | Accuracy                  | 0     
---------------------------------------------------------
86.1 M    Trainable params
0         Non-trainable params
86.1 M    Total params
344.204   Total estimated model params size (MB)


In [396]:
!nvidia-smi

Fri Jun 10 11:32:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.73.05    Driver Version: 510.73.05    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 53%   82C    P2   131W / 250W |  12102MiB / 12288MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 22%   36C    P8    16W / 250W |      3MiB / 12288MiB |      0%      Defaul

In [397]:
trainer.test(model, ckpt_path=None)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: 0it [00:00, ?it/s]

pixel_values as inputs <class 'torch.Tensor'>
type of outputs  torch.Size([8, 164])


TypeError: cross_entropy_loss(): argument 'target' (position 2) must be Tensor, not numpy.ndarray

In [None]:
!pwd

In [None]:
# trainer.save_checkpoint("example.ckpt")