In [None]:
import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import transformers
from transformers import Trainer,TrainingArguments,AutoImageProcessor,AutoModelForImageClassification,EarlyStoppingCallback
import datasets
from datasets import load_dataset
from sklearn.metrics import accuracy_score,precision_score
from collections import Counter

In [None]:
# Run a full garbage-collection pass; the displayed value is the number of
# unreachable objects it found.
unreachable_found = gc.collect()
unreachable_found

In [None]:
# Sanity check: report whether PyTorch can see a CUDA device.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Hugging Face Hub checkpoint used for both the image processor and the model.
model_id = "google/vit-base-patch16-224"

In [None]:
# Load the 'train' split of the Mars terrain dataset from the Hugging Face Hub.
dataset = load_dataset(path='aum27/mars-terrain', split='train')

In [None]:
# Inspect the dataset summary alongside its first example.
(dataset, dataset[0])

In [None]:
# Image processor matching the checkpoint; its preprocessing settings come from the
# model repo. trust_remote_code is deliberately not passed: ViT's processor is built
# into transformers, so there is no reason to allow running code downloaded from the Hub.
feature_extractor = AutoImageProcessor.from_pretrained(model_id)

In [None]:
def preprocess(batch):
    """Turn a batch of images into model-ready pixel tensors.

    Meant for ``Dataset.map(..., batched=True)``: ``batch['image']`` is a list
    of image objects (presumably PIL images, given ``.convert``). Each image is
    converted to RGB and the whole list is run through the module-level
    ``feature_extractor``. The resulting ``pixel_values`` are written back onto
    ``batch``, which is returned with its other columns untouched.
    """
    rgb_images = list(map(lambda image: image.convert('RGB'), batch['image']))
    batch['pixel_values'] = feature_extractor(rgb_images, return_tensors='pt')['pixel_values']
    return batch

In [None]:
# Preprocess every image once up front, then hold out 20% for evaluation with a
# fixed seed so the split is reproducible. The raw 'image' column is dropped after
# preprocessing: nothing downstream reads it (training and the label counts use only
# 'pixel_values' and 'label'), so keeping it only makes the mapped dataset larger.
# NOTE: this rebinds `dataset` (Dataset -> DatasetDict), so re-running this cell
# without first re-running the load cell will fail.
dataset = (
    dataset
    .map(preprocess, batched=True, remove_columns=['image'])
    .train_test_split(test_size=0.2, seed=42)
)

In [None]:
# Expose only the model input and target columns, returned as torch tensors.
dataset.set_format(columns=['pixel_values', 'label'], type='torch')

In [None]:
# Class balance of each split. Labels are read as a whole column with formatting
# reset to plain Python objects. Indexing row by row under the torch format (set
# above with columns pixel_values + label) would materialise every row's
# 'pixel_values' tensor just to read its label.
train_labels = list(dataset['train'].with_format(None)['label'])
train_label_counts = Counter(train_labels)

val_labels = list(dataset['test'].with_format(None)['label'])
val_label_counts = Counter(val_labels)

print("Training dataset label counts:")
print(train_label_counts)

print("\nValidation dataset label counts:")
print(val_label_counts)

In [None]:
# Load the pretrained ViT with a freshly initialised 8-way classification head.
# ignore_mismatched_sizes lets the checkpoint's differently shaped head weights be
# re-initialised instead of raising. trust_remote_code is deliberately not passed:
# ViT ships with transformers, so there is no need to execute code from the model repo.
# NOTE(review): num_labels=8 is assumed to equal the number of classes in
# aum27/mars-terrain — confirm against the label counts printed above.
model = AutoModelForImageClassification.from_pretrained(model_id, num_labels=8, ignore_mismatched_sizes=True)

In [None]:
# Prefer the GPU when PyTorch can see one. Fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Move the model's parameters to the selected device; the model repr is the cell output.
model = model.to(device)
model

In [None]:
# Fine-tuning configuration: evaluate and checkpoint once per epoch, then reload the
# checkpoint with the best eval accuracy when training ends. load_best_model_at_end
# requires the eval and save strategies to match, hence both are 'epoch'.
# `eval_strategy` replaces the deprecated `evaluation_strategy` keyword, which newer
# transformers releases no longer accept (eval_strategy needs transformers >= 4.41).
training_args = TrainingArguments(
    output_dir='./output_results_vit',
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_checkpointing=True,
    do_eval=True,
    eval_strategy='epoch',
    num_train_epochs=5,
    save_strategy='epoch',
    learning_rate=6e-5,
    lr_scheduler_type='linear',
    logging_steps=500,
    load_best_model_at_end=True,
    # Must be a key returned by compute_metrics; Trainer reports it as 'eval_accuracy'.
    metric_for_best_model='accuracy',
)

In [None]:
def collate(batch):
    """Stack a list of dataset rows into a single model-ready batch.

    Args:
        batch: list of rows, each a mapping with ``'pixel_values'`` (a tensor or
            array-like image data) and ``'label'`` (a Python int or 0-d tensor).

    Returns:
        dict with ``'pixel_values'`` stacked along a new leading batch dimension
        and ``'labels'`` as a 1-D tensor. ``labels`` is the keyword Hugging Face
        classification models take for the targets.
    """
    # torch.as_tensor accepts tensors, numpy arrays and (nested) Python lists alike,
    # so rows work whether or not the dataset was formatted as torch.
    pixel_values = torch.stack([torch.as_tensor(row['pixel_values']) for row in batch])
    labels = torch.stack([torch.as_tensor(row['label']) for row in batch])
    return {'pixel_values': pixel_values, 'labels': labels}


In [None]:
def compute_metrics(preds):
    """Compute accuracy and weighted precision for ``Trainer`` evaluation.

    Args:
        preds: the evaluation prediction object passed by ``Trainer``. Its
            ``predictions`` hold per-class scores (e.g. logits) of shape
            (n_samples, n_classes), and ``label_ids`` hold the integer labels.

    Returns:
        dict with ``'accuracy'`` and ``'precision'``; precision is averaged over
        classes weighted by their support.
    """
    labels = preds.label_ids
    scores = preds.predictions
    # Models that return extra outputs yield a tuple whose first element holds the logits.
    if isinstance(scores, tuple):
        scores = scores[0]
    predicted = np.argmax(scores, axis=1)

    acc = accuracy_score(labels, predicted)
    # zero_division=0 gives the same value sklearn's default ('warn') reports for
    # classes that are never predicted, without an UndefinedMetricWarning each epoch.
    prec = precision_score(labels, predicted, average='weighted', zero_division=0)

    return {'accuracy': acc, 'precision': prec}

In [None]:
# Wire model, data, batching and metrics together. EarlyStoppingCallback halts
# training once the metric chosen in training_args fails to improve for 3
# consecutive evaluations.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

In [None]:
# Fine-tune the model; the training summary returned by Trainer.train() is displayed.
train_result = trainer.train()
train_result

In [None]:
# Evaluate on the held-out split; the returned metrics dict is displayed.
eval_metrics = trainer.evaluate()
eval_metrics