In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import gc
import os 

# clear cuda memory and collect garbage
gc.collect()
torch.cuda.empty_cache()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device: ',device)

x=np.load('scaled_spec_resampled_array.npy')
x=x[:,724:1324,:]
y=np.load('labels_array.npy')-1 
x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])

print(x.shape, y.shape)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

train_dataset = MyDataset(x_train, y_train)
test_dataset = MyDataset(x_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Device:  cuda:0
(1754, 1, 600, 80) (1754,)


In [6]:
from transformers import AutoImageProcessor, ViTForImageClassification
from transformers import Trainer, TrainingArguments
import evaluate
from datasets import load_dataset, DatasetDict

dataset = DatasetDict({'train': train_dataset, 'test': test_dataset})


processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")




In [10]:
def transforms(batch):
    batch["pixel_values"] = processor(batch["pixel_values"], return_tensors="pt").pixel_values
    return batch


In [11]:
def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

In [16]:
from sklearn.metrics import accuracy_score
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    score=accuracy_score(labels, predictions)
    return score


In [12]:
model=ViTForImageClassification.from_pretrained("google/vit-base-patch16-224",num_labels=6,ignore_mismatched_sizes=True)


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([6]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
for name, layer in model.named_parameters():
    if not name.startswith("classifier"):
        layer.requires_grad = False

num_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {num_params}, Trainable parameters: {trainable_params}")

Total parameters: 85803270, Trainable parameters: 4614


In [18]:
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_dir="logs",
    logging_steps=100,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    report_to='tensorboard',
    metric_for_best_model='accuracy',
    greater_is_better=True,
)


trainer=Trainer(
    model=model,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor,
    compute_metrics=compute_metrics,
)


In [20]:
trainer.train()

  0%|          | 0/528 [00:00<?, ?it/s]

TypeError: tuple indices must be integers or slices, not str