In [1]:
from argparse import ArgumentParser
import numpy as np
import os
import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Callback, seed_everything
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset

import json

from pytorch_lightning.loggers import WandbLogger

from transformers import ViTForImageClassification, AdamW
import torch.nn as nn

from PIL import Image
from transformers import ViTFeatureExtractor

from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

from psychloss import RtPsychCrossEntropyLoss

In [2]:
# define dataset classes, pytorch lightning modules

# you might need to change a few things
# like the idx_to_class and viceversa to be like huggingface
#   feature extractor in the dataset, perhaps

In [3]:
def get_class_to_idx():
    return {0: '00003', 1: '00020', 2: '00036', 3: '00052', 4: '00067', 5: '00086', 6: '00108', 7: '00131', 8: '00152', 9: '00168', 10: '00185', 11: '00209', 12: '00225', 13: '00241', 14: '00259', 15: '00277', 16: '00294', 17: '00314', 18: '00332', 19: '00349', 20: '00365', 21: '00383', 22: '00409', 23: '00004', 24: '00021', 25: '00037', 26: '00053', 27: '00068', 28: '00087', 29: '00109', 30: '00132', 31: '00153', 32: '00169', 33: '00186', 34: '00210', 35: '00226', 36: '00242', 37: '00260', 38: '00278', 39: '00295', 40: '00315', 41: '00333', 42: '00350', 43: '00366', 44: '00386', 45: '00410', 46: '00005', 47: '00022', 48: '00039', 49: '00054', 50: '00069', 51: '00088', 52: '00110', 53: '00133', 54: '00154', 55: '00170', 56: '00187', 57: '00211', 58: '00227', 59: '00243', 60: '00262', 61: '00279', 62: '00296', 63: '00316', 64: '00334', 65: '00351', 66: '00367', 67: '00387', 68: '00411', 69: '00006', 70: '00023', 71: '00040', 72: '00055', 73: '00070', 74: '00089', 75: '00111', 76: '00134', 77: '00155', 78: '00171', 79: '00189', 80: '00212', 81: '00228', 82: '00244', 83: '00263', 84: '00280', 85: '00297', 86: '00317', 87: '00335', 88: '00352', 89: '00368', 90: '00389', 91: '00412', 92: '00007', 93: '00024', 94: '00041', 95: '00056', 96: '00071', 97: '00090', 98: '00112', 99: '00135', 100: '00156', 101: '00172', 102: '00192', 103: '00213', 104: '00229', 105: '00245', 106: '00264', 107: '00281', 108: '00298', 109: '00318', 110: '00336', 111: '00353', 112: '00369', 113: '00390', 114: '00413', 115: '00008', 116: '00025', 117: '00042', 118: '00057', 119: '00073', 120: '00091', 121: '00113', 122: '00138', 123: '00157', 124: '00173', 125: '00195', 126: '00214', 127: '00230', 128: '00246', 129: '00265', 130: '00282', 131: '00299', 132: '00320', 133: '00337', 134: '00354', 135: '00370', 136: '00393', 137: '00009', 138: '00026', 139: '00043', 140: '00058', 141: '00074', 142: '00092', 143: '00114', 144: '00141', 145: '00158', 146: '00174', 147: '00197', 148: '00215', 149: '00231', 150: '00247', 151: '00266', 152: '00283', 153: '00302', 154: '00321', 155: '00338', 156: '00355', 157: '00372', 158: '00394', 159: '00010', 160: '00027', 161: '00044', 162: '00059', 163: '00076', 164: '00093', 165: '00115', 166: '00142', 167: '00159', 168: '00175', 169: '00199', 170: '00216', 171: '00233', 172: '00248', 173: '00269', 174: '00284', 175: '00303', 176: '00322', 177: '00339', 178: '00356', 179: '00374', 180: '00395', 181: '00011', 182: '00028', 183: '00045', 184: '00060', 185: '00078', 186: '00095', 187: '00116', 188: '00143', 189: '00161', 190: '00177', 191: '00200', 192: '00217', 193: '00234', 194: '00249', 195: '00270', 196: '00285', 197: '00304', 198: '00323', 199: '00340', 200: '00357', 201: '00375', 202: '00399', 203: '00012', 204: '00029', 205: '00046', 206: '00061', 207: '00079', 208: '00096', 209: '00120', 210: '00144', 211: '00162', 212: '00179', 213: '00202', 214: '00218', 215: '00235', 216: '00250', 217: '00271', 218: '00287', 219: '00305', 220: '00324', 221: '00341', 222: '00358', 223: '00376', 224: '00400', 225: '00013', 226: '00030', 227: '00047', 228: '00062', 229: '00080', 230: '00101', 231: '00121', 232: '00145', 233: '00163', 234: '00180', 235: '00204', 236: '00219', 237: '00236', 238: '00251', 239: '00272', 240: '00289', 241: '00306', 242: '00325', 243: '00343', 244: '00360', 245: '00377', 246: '00402', 247: '00016', 248: '00031', 249: '00048', 250: '00063', 251: '00081', 252: '00102', 253: '00122', 254: '00146', 255: '00164', 256: '00181', 257: '00205', 258: '00220', 259: '00237', 260: '00252', 261: '00273', 262: '00290', 263: '00307', 264: '00327', 265: '00344', 266: '00361', 267: '00378', 268: '00403', 269: '00017', 270: '00032', 271: '00049', 272: '00064', 273: '00083', 274: '00104', 275: '00123', 276: '00147', 277: '00165', 278: '00182', 279: '00206', 280: '00221', 281: '00238', 282: '00253', 283: '00274', 284: '00291', 285: '00308', 286: '00328', 287: '00346', 288: '00362', 289: '00379', 290: '00404', 291: '00018', 292: '00033', 293: '00050', 294: '00065', 295: '00084', 296: '00106', 297: '00127', 298: '00148', 299: '00166', 300: '00183', 301: '00207', 302: '00222', 303: '00239', 304: '00254', 305: '00275', 306: '00292', 307: '00311', 308: '00329', 309: '00347', 310: '00363', 311: '00380', 312: '00405', 313: '00019', 314: '00034', 315: '00051', 316: '00066', 317: '00085', 318: '00107', 319: '00129', 320: '00149', 321: '00167', 322: '00184', 323: '00208', 324: '00223', 325: '00240', 326: '00256', 327: '00276', 328: '00293', 329: '00312', 330: '00331', 331: '00348', 332: '00364', 333: '00381', 334: '00406'}

idx_to_class = get_class_to_idx()

In [4]:
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)


In [5]:
# # could just call this on the batch during the common_function

# def extract_vit_features(batch):
#     # Take a list of PIL images and turn them to pixel values
#     inputs = feature_extractor([x.cpu() for x in batch['img']], return_tensors='pt')

#     # Don't forget to include the labels!
#     inputs['label'] = batch['label']
#     inputs['rt'] = batch['rt']
#     return inputs

In [6]:

# class CustomDataModule(pl.LightningDataModule):
#     def __init__(self):
#         # as seen on github, you need to call super init here 
#         super().__init__()
        
#         batch_size = 16
#         self.num_labels = 1000
#         json_data_base = '/afs/crc.nd.edu/user/j/jdulay'

#         self.train_known_known_with_rt_path = os.path.join(json_data_base, "train_known_known_with_rt.json")
#         self.valid_known_known_with_rt_path = os.path.join(json_data_base, "valid_known_known_with_rt.json")

#         self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")


# #     def prepare_data(self):
# #         # download
# #         MNIST(self.data_dir, train=True, download=True)
# #         MNIST(self.data_dir, train=False, download=True)

#         normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
#         train_transforms = Compose(
#                 [
#                     RandomResizedCrop(self.feature_extractor.size),
#                     RandomHorizontalFlip(),
#                     ToTensor(),
#                     normalize,
#                 ]
#             )

#         val_transforms = Compose(
#                 [
#                     Resize(self.feature_extractor.size),
#                     CenterCrop(self.feature_extractor.size),
#                     ToTensor(),
#                     normalize,
#                 ]
#             )

#         self.train_known_known_with_rt_dataset = msd_net_dataset(json_path=self.train_known_known_with_rt_path,
#                                                                 transform=train_transforms)

#         # and this one hehe
#         self.valid_known_known_with_rt_dataset = msd_net_dataset(json_path=self.valid_known_known_with_rt_path,
#                                                                 transform=val_transforms) 

#     def train_dataloader(self):
#         train_dataloader = DataLoader(self.train_known_known_with_rt_dataset, batch_size=16)
#         return train_dataloader
        
#     def val_dataloader(self):
#         val_dataloader = DataLoader(self.valid_known_known_with_rt_dataset, batch_size=16)
#         return val_dataloader

#     def test_dataloader(self):
#         val_dataloader = DataLoader(self.valid_known_known_with_rt_dataset, batch_size=16)
#         return val_dataloader





def collate_fn(batch):
    
#     print('batch is ', batch)
#     1/0
    pixel_values = torch.stack([x["pixel_values"] for x in batch])
    
#     print("DEBUG: the batch is ", batch)
#     print("DEBUG: the batch[img] is ", batch['img'])

#     imgs = feature_extractor([p for p in pixel_values], return_tensors='pt')
    
#     print(len(pixel_values))
#     print(len(imgs))

#     assert len(pixel_values) == len(imgs)
    
    labels = torch.tensor([x["label"] for x in batch])
    rt = torch.tensor([x["rt"] for x in batch])

    return {"pixel_values": pixel_values, "label": labels, "rt": rt}


class msd_net_dataset(Dataset):
    def __init__(self,
                 json_path,
                 transform):

        with open(json_path) as f:
            data = json.load(f)
        #print("Json file loaded: %s" % json_path)

        self.data = data
        self.transform = transform
        self.random_weight = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[str(idx)]

        # Open the image and do normalization and augmentation
        img = Image.open(item["img_path"])
        img = img.convert('RGB')
        img = self.transform(img)

        # let's do our transform here instead 
#         img = feature_extractor(img, return_tensors='pt')
#         print(img)
#         print(img[0])
#         1/0
        # why is this not returning a pytorch tensor
        
#         preprocess = Compose([ToTensor()])       
#         img = preprocess(img)
        
        
#         print("HERE", type(img))
        # tokenize the image
        # but first, just do it in the collate function
        
        # Deal with reaction times
        if self.random_weight is None:
            # print("Checking whether an RT exists for this image...")
            if item["RT"] != None:
                rt = item["RT"]
            else:
                # print("RT does not exist")
                rt = 0 # for now 
        # No random weights for reaction time
        else:
            pass
        

        # did we ever use the correct label in commited code?
#         orig_label = item['label']
#         re_index_label = idx_to_class[orig_label]
#         print('re_index_label', re_index_label)

        return {
            "pixel_values": img,
            "label": item["label"],
            "rt": rt,
            "category": item["category"]
        }


In [7]:

class ViTLightningModule(pl.LightningModule):
    def __init__(self, vit, testdataset):
        super(ViTLightningModule, self).__init__()
        self.vit = vit
        

        # self.vit = torchvision.models.vit_b_32(pretrained=True) 
        self.num_labels=336
        self.criterion = nn.CrossEntropyLoss()
        self.fc = nn.Linear(4096, 336)
        
        
        self.test_dataset = testdataset
        #   num_labels=10,
        #   id2label=id2label,
        #   label2id=label2id

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
#         outputs = self.fc(outputs)
        return outputs.logits

    def train_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=16, collate_fn=collate_fn)

    def train_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=16, collate_fn=collate_fn)

    def test_dataloader(self):
        # but we don't have this dataset here right now, so we have to add it, as a class call
        return DataLoader(self.test_dataset, batch_size=16, collate_fn=collate_fn)

    
    def common_step(self, batch, batch_idx):
        # TODO: implement w RT 
        
#         feats, labels = extract_vit_features(batch)
        pixel_values = batch['pixel_values']
        labels = batch['label']
        labels = [int(i) for i in labels]
        labels = torch.as_tensor(labels)
        
        
        rts = batch['rt']

#         print("INFO. feats: {} - labels {} - rts {} --".format(pixel_values, labels, rts))
#         print("INFO. feats: {} - labels {} - rts {} --".format(type(pixel_values), type(labels), type(rts)))
                
        #print('beginning of common step')
        #print('pixel_values', pixel_values.shape, pixel_values)
        #print('lbaels,', labels.shape, labels)
    
        logits = self(pixel_values)
        
        
        
        #print('logits', logits.shape, logits)
        
#         print("INFO. logits: {} - labels {} - rts {} --".format(logits, labels, rts))
#         print("INFO. logits: {} - labels {} - rts {} --".format(type(logits), type(labels), type(rts)))

        
        loss = RtPsychCrossEntropyLoss(logits, labels, rts)
#         loss = self.criterion(logits, labels)

        predictions = logits.argmax(-1)
#         print('predictions are', predictions)
#         print('labels are', labels)
        
        #print('debug here')
        correct = (predictions.cpu() == labels.cpu()).sum().item()
        accuracy = correct/pixel_values.shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        # print('batch is in testing', batch)
        # 1/0
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)
        
        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=5e-5)
    

In [8]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

json_data_base = '/afs/crc.nd.edu/user/j/jdulay'
train_known_known_with_rt_path = os.path.join(json_data_base, "train_known_known_with_rt.json")
testdataset = msd_net_dataset(json_path=train_known_known_with_rt_path,
                                        transform=train_transforms)

testdataset

<__main__.msd_net_dataset at 0x14b4f279b358>

In [9]:
# prepare the transforms and dataloader here now:



In [10]:
import os
import wandb

path = '/afa/crc.nd.edu/user/j/jdulay/.cache/'
if os.path.isdir(path):
    os.rmdir(path)
    
# wandb_logger = None
# logger_name = '01_general_model_search_DEBUG'
# wandb_logger = WandbLogger(name=logger_name, project="general_model_search_DEBUG")
vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224-in21k', 
#             id2label={str(i): c for i, c in enumerate(labels)},
#             label2id={c: str(i) for i, c in enumerate(labels)}, 
            num_labels=165)

model = ViTLightningModule(vit=vit,testdataset=testdataset)

trainer = pl.Trainer(
    max_epochs=20, 
    devices=1, 
    accelerator='gpu',
    # strategy='ddp',
    auto_select_gpus=True, 
#     logger=wandb_logger,
#     callbacks=[metrics_callback],
    progress_bar_refresh_rate=0,
    limit_train_batches=0,
    limit_val_batches=0
) # hacks 

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Auto select gpus: [0]
  f"Settin

In [11]:
trainer.fit(model)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type                      | Params
--------------------------------------------------------
0 | vit       | ViTForImageClassification | 85.9 M
1 | criterion | CrossEntropyLoss          | 0     
2 | fc        | Linear                    | 1.4 M 
--------------------------------------------------------
87.3 M    Trainable params
0         Non-trainable params
87.3 M    Total params
349.209   Total estimated model params size (MB)


In [12]:
trainer.test(model, ckpt_path=None)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
../aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [1,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [2,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [3,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,0], thread: [4,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:91: operator(): block: [0,0,

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.