In [7]:
import sys
sys.path.append('/data/vision/beery/scratch/neha/micromamba/envs/datacomp/lib/python3.11/site-packages')
sys.path.append('/data/vision/beery/scratch/neha/task-datacomp/')

from transformers import AutoImageProcessor, AutoModelForImageClassification
from all_datasets.COOS_dataset import COOSDataset
from all_datasets.FMoW_dataset import FMoWDataset
import transformers
import timm
import torch
import peft
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from torch.utils.data import DataLoader
from peft import get_peft_model, LoraConfig, PeftModel


In [27]:
model = timm.create_model('vit_base_patch16_224.augreg2_in21k_ft_in1k', pretrained=True)
data_config = timm.data.resolve_model_data_config(model)
# transforms = timm.data.create_transform(**data_config, is_training=False)
preprocess_train = timm.data.create_transform(**data_config, is_training=True, no_aug=True)
preprocess_eval = timm.data.create_transform(**data_config, is_training=False)

class ViTConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ViTModel(PreTrainedModel):
    config_class = ViTConfig

    def __init__(self, model, config):
        super().__init__(config)
        self.model = model
        self.blocks = model.blocks

    def forward(self, x):
        return self.model(x)

In [28]:
def get_lora_model(model):
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.01,
        bias='none',
        target_modules=['qkv'],
        modules_to_save=["classifier"],
    )
    extractor_model = get_peft_model(ViTModel(model, ViTConfig()), config).to('cuda')
    return extractor_model

train_ds=FMoWDataset('train',transform=preprocess_train)
val_ds=FMoWDataset('test1',transform=preprocess_eval)
train_dataloader=DataLoader(train_ds, batch_size=256)
val_dataloader=DataLoader(val_ds, batch_size=256)
model=get_lora_model(model)

In [29]:
train_ds[0]

(tensor([[[-0.0039,  0.4353, -0.1059,  ...,  0.8824,  0.8824,  0.8667],
          [ 0.2000,  0.4824,  0.0353,  ...,  0.8667,  0.8431,  0.8510],
          [ 0.3333,  0.4275, -0.1843,  ...,  0.8196,  0.8588,  0.7176],
          ...,
          [-0.4196, -0.3804, -0.4196,  ..., -0.4196, -0.5059, -0.1843],
          [-0.4510, -0.4353, -0.4118,  ..., -0.3098, -0.5059, -0.2392],
          [-0.4196, -0.4275, -0.3725,  ..., -0.4118, -0.2706, -0.0353]],
 
         [[ 0.1922,  0.7412,  0.0353,  ...,  0.9451,  0.8980,  0.8902],
          [ 0.4039,  0.8118,  0.2627,  ...,  0.9059,  0.8667,  0.8902],
          [ 0.5294,  0.7098,  0.0196,  ...,  0.8353,  0.8902,  0.7882],
          ...,
          [-0.0196,  0.0824,  0.0353,  ..., -0.2784, -0.3725, -0.0196],
          [-0.0745, -0.0431,  0.0039,  ..., -0.1922, -0.3882, -0.0902],
          [-0.0196, -0.0431,  0.0510,  ..., -0.3569, -0.1843,  0.0980]],
 
         [[ 0.2078,  0.8431,  0.0353,  ...,  0.9294,  0.8824,  0.8275],
          [ 0.4588,  0.9137,

In [24]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 10
device='cuda'
model.to(device)
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_dataloader, 0),total=len(train_dataloader)):
        # Get the inputs and labels
        inputs, _, labels, _ = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 100 == 0:  # Print every 100 mini-batches
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
            running_loss = 0.0

 11%|█████████████                                                                                                         | 1/9 [00:02<00:21,  2.67s/it]

[1, 1] loss: 0.015


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.39s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.60s/it]

[2, 1] loss: 0.056


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.61s/it]

[3, 1] loss: 0.038


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.60s/it]

[4, 1] loss: 0.020


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.36s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.60s/it]

[5, 1] loss: 0.022


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.60s/it]

[6, 1] loss: 0.020


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.60s/it]

[7, 1] loss: 0.019


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.60s/it]

[8, 1] loss: 0.019


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:20,  2.59s/it]

[9, 1] loss: 0.019


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.37s/it]
 11%|█████████████                                                                                                         | 1/9 [00:02<00:21,  2.63s/it]

[10, 1] loss: 0.019


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:21<00:00,  2.36s/it]


In [25]:
model.eval()

# Example validation data (replace with your actual validation data)
correct = 0
total = 0
with torch.no_grad():
    for data in val_dataloader:
        images,_, labels,_ = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
correct/total

0.238

0.066