In [1]:
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import get_scheduler

##### Own
import train_utils.cifar_utils as cifar_utils
from train_utils import make_optimizer, get_cfg

from vision_transformer import VisionTransformer
from train_utils import cifar_utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cfg = get_cfg("config/vit_train.yml")

####### Dataset setup #######
dataset_cfg = cfg["cifar_dataset"]

label2id, id2label = cifar_utils.get_label_dicts(dataset_cfg["label_type"])

train_dataloader, validation_dataloader, test_dataloader = cifar_utils.dataloaders_from_cfg(cfg)

model = VisionTransformer(
    image_size=cfg["cifar_dataset"]["image_size"], use_linear_patch=True, num_classes=len(label2id.keys()))

num_epochs = 1 # TODO: set a param in the config file
lr = 0.003
num_training_steps = num_epochs * len(train_dataloader)

optimizer = make_optimizer(optimizer_name='adamw',model=model, lr=0.003)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )
loss_function = nn.CrossEntropyLoss()

# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model = model.to(device)

In [5]:
# Passing single batch to model
batch = next(iter(train_dataloader))
pred = model(batch["pixel_values"])

In [9]:
# ONXX
input_val = batch
torch.onnx.export(model, input_val , "model.onnx", input_names=['pixel_values'], output_names=["label"])

TypeError: VisionTransformer.forward() missing 1 required positional argument: 'x'

In [3]:
# single epoch training loop
model.train()
epoch_train_loss = 0

for batch_idx, batch in enumerate(train_dataloader):
    # transfer batch to device
    batch = {k: v.to(device) for k, v in batch.items()}

    # forward pass and loss calculation
    outputs = model(batch["pixel_values"])
    loss = loss_function(outputs, batch["coarse_label"])
    loss.backward()
    epoch_train_loss += loss.item()

    # backward pass
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
epoch_train_loss/=len(train_dataloader)

# logger.info(f'Validating at epoch {epoch}')
epoch_val_loss = 0
with torch.no_grad():
    model.eval()
    total_examples, correct_predictions= 0.0, 0.0
    for batch_idx, batch in enumerate(train_dataloader):
        # transfer batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(batch["pixel_values"])
        
        epoch_val_loss += loss_function(outputs, batch["coarse_label"]).item() # retrieve only the scalar value
        pred_labels = outputs.argmax(dim=1)
        
        total_examples += float(len(batch['coarse_label']))
        correct_predictions += float((batch["coarse_label"] == pred_labels).sum().item())

    acc = correct_predictions / total_examples
    epoch_val_loss /= len(train_dataloader)

    print(f'-- train loss {train_loss:.3f} -- validation accuracy {acc:.3f} -- validation loss: {epoch_val_loss:.3f}')
    if epoch_val_loss <= best_val_loss and save_model:
        torch.save(model.state_dict(), 'model.pth')
        best_val_loss = epoch_val_loss

In [25]:
epoch_val_loss = 0
correct_predictions = 0
# model = model.train()

In [None]:
print(len(train_dataloader))

In [None]:
correct_predictions = 0
batch = next(iter(train_dataloader))
batch = {k: v.to(device) for k, v in batch.items()}
print(f"batch_size: {batch['pixel_values'].shape}")
print(f"batch labels: {batch['coarse_label']}")

outputs = model(batch["pixel_values"])

loss = loss_function(outputs, batch["coarse_label"])
epoch_val_loss += loss.item() # retrieve only the scalar value
pred_labels = outputs.argmax(dim=1)

print(f"outputs: {outputs.shape}")
print(f"loss: {loss}")
print(f"pred_labels: {pred_labels}")

correct_predictions += int((batch["coarse_label"] == pred_labels).sum().item())
print(f"correct: {correct_predictions}")

In [None]:

for batch_idx, batch in enumerate(train_dataloader):
    print(batch["coarse_label"])
    if batch_idx > 2:
        break
# batch = next(iter(train_dataloader))
# # print((batch))
# outputs = model(batch["pixel_values"])

In [18]:
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

In [None]:
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

In [None]:
from vision_transformer import VisionTransformer

model = VisionTransformer()

dummy_input = 