In [None]:
# For Google Colab
# !pip install roboflow
# !pip install -U transformers
# !pip install datasets
# !pip install wandb
# !pip install accelerate -U

In [None]:
import os
import yaml
import json
import wandb
import torch
import shutil
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms


from PIL import Image
from roboflow import Roboflow
from torch.utils.data import Dataset
from torchvision import transforms, datasets
from sklearn.model_selection import train_test_split
from datasets import load_dataset, DatasetDict, load_metric
from transformers import ViTForImageClassification, ViTImageProcessor, TrainingArguments, Trainer, AutoImageProcessor

In [None]:
def download_roboflow_data(config):
    """
    Fetch one version of a RoboFlow project and place it under the data path.

    Args:
        config: Parsed run configuration. Reads ``config['data']['roboflow']``
            (keys ``api_key``, ``workspace``, ``project``, ``version``,
            ``version_download``) and ``config['data']['path']``.

    Returns:
        A ``(dataset, dest_path)`` tuple: the object returned by the RoboFlow
        download call and the directory ``<data path>/<dataset name>``.
    """
    rf_settings = config['data']['roboflow']

    # Walk workspace -> project -> version, then download in the requested format
    client = Roboflow(api_key=rf_settings["api_key"])
    workspace = client.workspace(rf_settings["workspace"])
    project_version = workspace.project(rf_settings["project"]).version(rf_settings["version"])
    dataset = project_version.download(model_format=rf_settings["version_download"])

    dest_path = config['data']['path'] + "/" + dataset.name

    # Only relocate the download when the destination is not already present
    already_present = os.path.exists(dest_path)
    if not already_present:
        shutil.move(src=dataset.location, dst=dest_path)

    print(f"Dataset downloaded and extracted to {config['data']['path']}")
    return dataset, dest_path

In [None]:
def load_config(config_path):
    """
    Read a YAML file and return its parsed content.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.
    """
    with open(config_path, 'r') as config_stream:
        parsed = yaml.safe_load(config_stream)
    return parsed

In [None]:
def create_transform(aug_config):
    """
    Build the torchvision pipeline applied to every image.

    Args:
        aug_config: Augmentation settings. Currently unused, because the
            config-driven augmentations below are commented out.

    Returns:
        A ``transforms.Compose`` that resizes to 224x224 and converts to a tensor.
    """
    # Mandatory steps: resizing and ToTensor (normalization is disabled)
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]

    # Disabled config-driven augmentations; when re-enabled they must be
    # inserted ahead of the mandatory steps above.
    # if 'random_resize_crop' in aug_config:
    #     steps.insert(0, transforms.RandomResizedCrop(**aug_config['random_resize_crop']))
    # if 'random_horizontal_flip' in aug_config:
    #     steps.insert(0, transforms.RandomHorizontalFlip(aug_config['random_horizontal_flip']))
    # if 'color_jitter' in aug_config:
    #     steps.insert(0, transforms.ColorJitter(**aug_config['color_jitter']))
    # if 'random_rotation' in aug_config:
    #     steps.insert(0, transforms.RandomRotation(aug_config['random_rotation']))

    return transforms.Compose(steps)

In [None]:
def get_transforms(config):
    """
    Create the training and validation transform pipelines.

    Args:
        config: Parsed run configuration. ``config['data']['train_augmentation']``
            is required; ``config['data']['val_augmentation']`` is optional and
            defaults to an empty dict.

    Returns:
        A ``(train_transform, val_transform)`` tuple.
    """
    data_settings = config['data']
    return (
        create_transform(data_settings['train_augmentation']),
        create_transform(data_settings.get('val_augmentation', {})),
    )

In [None]:
def organize_images_by_class(src_ds_path, ds_final_path):
    """
    Regroup a split dataset into one folder per class.

    Every regular file found directly inside ``train``, ``valid`` and ``test``
    under ``src_ds_path`` is moved to ``ds_final_path/<LETTER>/``, where
    ``<LETTER>`` is the upper-cased first character of the file name. Files
    whose first character is not a letter are not moved. Split directories
    that do not exist are skipped instead of raising.

    Args:
        src_ds_path: Root of the downloaded dataset. It is deleted, together
            with everything that was not moved, once the files are sorted.
        ds_final_path: Root of the per-class layout; created if missing.
    """
    # List of subdirectories to process
    sub_dirs = ['train', 'valid', 'test']

    os.makedirs(ds_final_path, exist_ok=True)

    for sub_dir in sub_dirs:
        current_dir = os.path.join(src_ds_path, sub_dir)

        # Not every export ships all three splits; a missing one is not an error
        if not os.path.isdir(current_dir):
            continue

        for f in os.listdir(current_dir):
            src_path = os.path.join(current_dir, f)
            if not os.path.isfile(src_path):
                continue

            # The class is encoded as the first letter of the file name
            first_letter = f[0].upper()
            if not first_letter.isalpha():
                continue

            # Create the class directory on demand, then move the file into it
            letter_dir = os.path.join(ds_final_path, first_letter)
            os.makedirs(letter_dir, exist_ok=True)
            shutil.move(src_path, os.path.join(letter_dir, f))

    shutil.rmtree(src_ds_path)

    # NOTE(review): a stray copy of the Project dir was observed after running
    # this step; the cause is unknown, so the cleanup below stays disabled.
    # shutil.rmtree("src/classification/Project")
    print("Image organization complete!")

In [None]:
# YAML files holding the run settings and the Weights & Biases settings
f_run_config, f_wandb_config = "config.yml", "wandb.yml"

In [None]:
# Parse the run configuration first, then the Weights & Biases configuration
config, wandb_config = map(load_config, (f_run_config, f_wandb_config))

In [None]:
dataset_name = "Guitar-Chords"

# Download data from RoboFlow if specified, then regroup it into one folder per
# class. Both steps sit under the same guard: `location` only exists after a
# download, so organizing unconditionally raised NameError when
# `use_roboflow` was false or absent.
if config['data'].get('use_roboflow', False):
    _, location = download_roboflow_data(config)
    organize_images_by_class(location, "datasets/" + dataset_name)

In [None]:
# Load the pre-trained model and its image processor from the configured checkpoint
# NOTE(review): `processor` and `model` are both re-assigned further down,
# where the model is rebuilt with the dataset's label names.
processor = ViTImageProcessor.from_pretrained(config['model']['pretrained_weights'])
model = ViTForImageClassification.from_pretrained(config['model']['pretrained_weights'])

In [None]:
# Build the training pipeline and the base (evaluation) pipeline from the run config
train_transform, base_transform = get_transforms(config=config)

In [None]:
def preprocess(batch, is_train=True):
    """
    Turn a batch of PIL images into model inputs.

    Args:
        batch: Mapping with an ``'image'`` list of PIL images and a ``'label'``
            list, as handed over by ``with_transform``.
        is_train: Apply the training pipeline when true, the base pipeline
            otherwise.

    Returns:
        The image processor output with the batch labels stored under ``'label'``.
    """
    # Reuse the module-level pipelines instead of rebuilding them for every batch
    transform = train_transform if is_train else base_transform
    image_tensors = [transform(image.convert("RGB")) for image in batch['image']]

    # The pipelines end with ToTensor(), which already scales pixels to [0, 1].
    # The processor rescales by 1/255 again unless told not to, which squashed
    # the inputs to roughly [0, 0.004] before normalization.
    inputs = processor(image_tensors, return_tensors='pt', do_rescale=False)
    inputs['label'] = batch['label']

    return inputs

In [None]:
# Load the ds
ds = load_dataset("imagefolder", data_dir="datasets/Guitar-Chords")

# Fixed seed so that re-running the cell reproduces the same split
SPLIT_SEED = 42

# Split the data: 70 % train, 30 % hold-out (stratified by class) ...
ds = ds['train'].train_test_split(test_size=0.3, stratify_by_column="label", seed=SPLIT_SEED)
# ... then halve the hold-out into validation and test
ds_test = ds['test'].train_test_split(test_size=0.5, stratify_by_column="label", seed=SPLIT_SEED)

# Validation previously reused the training split, so the per-epoch accuracy
# (and best-model selection) was measured on data the model had trained on,
# while `ds_test` was never used. Validation and test now come from the two
# disjoint halves of the hold-out.
ds = DatasetDict({
    'train': ds['train'].with_transform(lambda batch: preprocess(batch, True)),
    'valid': ds_test['train'].with_transform(lambda batch: preprocess(batch, False)),
    'test': ds_test['test'].with_transform(lambda batch: preprocess(batch, False)),
})

In [None]:
# Feature descriptor of the `label` column (presumably a ClassLabel — confirm);
# its `.names` list is read later to size the classification head and to build
# the id <-> label mappings.
labels = ds['train'].features['label']

In [None]:
def collate_fn(batch):
    """
    Merge individual samples into one batch for the Trainer.

    Args:
        batch: Sequence of samples, each exposing a ``'pixel_values'`` tensor
            and a ``'label'`` value.

    Returns:
        Dict with the stacked ``'pixel_values'`` tensor and a ``'labels'``
        tensor built from the per-sample labels.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample['pixel_values'])
        targets.append(sample['label'])

    # Note the key change: samples carry 'label', the batch carries 'labels'
    return {'pixel_values': torch.stack(images), 'labels': torch.tensor(targets)}

# Accuracy metric object loaded through `datasets.load_metric`.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases —
# confirm the installed version still provides it.
metric = load_metric("accuracy")

In [None]:
def compute_metrics(p):
    """
    Compute classification accuracy for one evaluation pass.

    Accuracy is computed directly with NumPy, so the function does not depend
    on a separately loaded metric object.

    Args:
        p: Prediction object exposing ``predictions`` (per-class scores, one
            row per sample) and ``label_ids`` (reference class indices).

    Returns:
        ``{"accuracy": <float>}`` — the fraction of samples whose arg-max
        prediction equals the reference label.
    """
    predicted = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted == references))}

In [None]:
# Checkpoint that is fine-tuned below
model_name_or_path = 'google/vit-base-patch16-224-in21k'

# Rebuild the classifier with one output per dataset label; the label mappings
# are stored with string ids in both directions.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    ignore_mismatched_sizes=True,
    num_labels=len(labels.names),
    label2id={name: str(idx) for idx, name in enumerate(labels.names)},
    id2label={str(idx): name for idx, name in enumerate(labels.names)},
)

# Image processor matching the same checkpoint
processor = AutoImageProcessor.from_pretrained(model_name_or_path)

In [None]:
# Start a Weights & Biases run; a generated id is appended to the configured
# name so that every run gets a distinct name.
wandb.require("core")
wandb.init(
    entity=wandb_config["entity"],
    project=wandb_config["project"],
    name="-".join((wandb_config['name'], wandb.util.generate_id())),
    config=wandb_config,
)

# Training arguments, grouped by concern
training_args = TrainingArguments(
    # Schedule and optimisation (values come from the run config)
    num_train_epochs=config['training']['num_epochs'],
    learning_rate=float(config['training']['learning_rate']),
    per_device_train_batch_size=config['training']['batch_size'],
    per_device_eval_batch_size=config['training']['batch_size'],
    # fp16=True
    # Evaluate and checkpoint once per epoch, keep one checkpoint, and reload
    # the most accurate one at the end
    output_dir=config['training']['output_dir'],
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Logging
    report_to="wandb",
    logging_steps=500,
    # Do not let the Trainer drop dataset columns it does not recognise
    remove_unused_columns=False,
)

# Wire model, data, collation and metrics into the Trainer.
# NOTE(review): newer transformers releases prefer `processing_class=` over
# `tokenizer=` — confirm against the installed version.
trainer = Trainer(
    model=model,
    tokenizer=processor,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["valid"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# Train the model. The wandb run is closed on every path: previously an
# exception (or a manual interrupt) during training skipped `wandb.finish()`
# and left the run open.
try:
    trainer.train()
except BaseException:
    # Close the run with a non-zero exit code so it is reported as failed,
    # then let the original error surface in the notebook
    wandb.finish(exit_code=1)
    raise
else:
    # # Save the fine-tuned model
    # trainer.save_model(config['training']['final_model_path'])

    # Close wandb run
    wandb.finish()