In [None]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
import cv2
import random
import torch.nn as nn
from transformers import AutoFeatureExtractor
from datasets import load_dataset,load_metric
import shutil
from transformers import SwinForImageClassification, Trainer, TrainingArguments
from transformers import Trainer, TrainingArguments
from PIL import Image

In [None]:
# Materialise an 80/20 train/test split of the Kaggle plant images as
# data/{train,test}/<class_name>/<image>, one sub-folder per class.
data_path = "data"
data_set = ["train","test"]
original_folder = "/kaggle/input/plants-classification/data"
train_folder = os.path.join(data_path,"train")
test_folder = os.path.join(data_path,"test")
train_ratio = 0.8   # fraction of each class's images copied to train/
split_seed = 42     # fixes which images land in train vs. test

for d in data_set:
    # makedirs also creates `data_path` and tolerates re-running this cell.
    os.makedirs(os.path.join(data_path, d), exist_ok=True)

def _copy_images(names, src_dir, dst_dir):
    """Copy each file in ``names`` from ``src_dir`` into ``dst_dir``."""
    for name in names:
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))

# Private RNG: a reproducible split without reseeding the global `random` module.
split_rng = random.Random(split_seed)

# os.listdir returns entries in arbitrary order, so sort before shuffling;
# otherwise the split (and the RNG consumption order) would vary by filesystem.
# NOTE: copies are only ever added, so after changing split_seed/train_ratio
# delete `data_path` first to avoid stale files leaking between splits.
for class_name in sorted(os.listdir(original_folder)):
    class_path = os.path.join(original_folder, class_name)
    if not os.path.isdir(class_path):
        continue  # ignore stray files next to the class folders
    class_train_path = os.path.join(train_folder, class_name)
    class_test_path = os.path.join(test_folder, class_name)
    os.makedirs(class_train_path, exist_ok=True)
    os.makedirs(class_test_path, exist_ok=True)

    imgs_name = sorted(os.listdir(class_path))
    # Shuffle so each class is split randomly rather than by filename order.
    split_rng.shuffle(imgs_name)
    n_train = int(len(imgs_name) * train_ratio)

    _copy_images(imgs_name[:n_train], class_path, class_train_path)
    _copy_images(imgs_name[n_train:], class_path, class_test_path)

In [None]:
# Load the class-per-folder train/test split written under `data_path`; the
# imagefolder builder derives the `label` column from the sub-folder names.
# Using the same relative `data_path` (instead of a Kaggle-only absolute path)
# keeps this cell working wherever the split cell was run.
data = load_dataset("imagefolder", data_dir=data_path)

In [None]:
# Pretrained Swin checkpoint to fine-tune, plus the preprocessing config
# published alongside it.
model_name = 'microsoft/swin-base-patch4-window7-224'
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

def transform(example_batch):
    """Turn a batch of PIL images into model-ready pixel values.

    Args:
        example_batch: dict of column lists containing at least ``'image'``
            (PIL images) and ``'label'``.

    Returns:
        The feature extractor's output (PyTorch tensors, including
        ``pixel_values``) with the batch's ``'label'`` list attached.
    """
    # Force 3 channels: grayscale/RGBA inputs would otherwise break the extractor.
    rgb_images = [picture.convert('RGB') for picture in example_batch['image']]
    encoded = feature_extractor(rgb_images, return_tensors='pt')
    encoded['label'] = example_batch['label']
    return encoded

# with_transform applies `transform` on the fly whenever rows are accessed.
prepared_ds = data.with_transform(transform)

In [None]:
def collate_fn(batch):
    """Merge a list of transformed examples into one Trainer batch.

    Args:
        batch: list of dicts, each with a ``'pixel_values'`` tensor and an
            integer ``'label'``.

    Returns:
        dict with ``'pixel_values'`` stacked along a new leading batch
        dimension and ``'labels'`` as a 1-D tensor (the key name the model's
        forward expects for computing the loss).
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}

def compute_metrics(p):
    """Top-1 accuracy for the Trainer's evaluation loop.

    Computed directly with NumPy instead of ``datasets.load_metric``, which is
    deprecated (removed in datasets 3.0) and may fetch a metric script from
    the Hub at runtime.

    Args:
        p: an ``EvalPrediction``-like object with ``predictions`` (logits, one
            row per example, or a tuple whose first element is the logits)
            and ``label_ids``.

    Returns:
        ``{"accuracy": float}``. The key must stay ``"accuracy"`` because
        ``metric_for_best_model="accuracy"`` selects checkpoints by it.
    """
    # Some models return extra outputs, in which case predictions is a tuple.
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(logits, axis=1)
    labels = np.asarray(p.label_ids)
    if labels.size == 0:
        return {"accuracy": float("nan")}
    return {"accuracy": float(np.mean(preds == labels))}

In [None]:
# Class names in label-index order, taken from the train split's ClassLabel feature.
labels = data['train'].features['label'].names

# Re-head the pretrained Swin for our class count. ignore_mismatched_sizes lets
# the new classifier be freshly initialised instead of erroring on the size
# mismatch with the checkpoint's original head.
model = SwinForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
    num_labels=len(labels),
    id2label=dict((str(idx), name) for idx, name in enumerate(labels)),
    label2id=dict((name, str(idx)) for idx, name in enumerate(labels)),
)

In [None]:
batch_size = 16

# NOTE: newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; switch the name if the installed version rejects it.
# NOTE(review): the best checkpoint is chosen on prepared_ds["test"], the same
# split reported as the final eval, so that accuracy is optimistically biased;
# consider holding out a separate validation split.
training_args = TrainingArguments(
    "swin-finetuned-plants-classification",
    # Keep the raw `image` column: `transform` reads it, but the model's
    # forward signature does not include it.
    remove_unused_columns=False,
    # Batching: each optimizer step accumulates 4 micro-batches per device.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    # Optimisation schedule.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=100,
    # Evaluate and checkpoint every epoch; restore the most accurate at the end.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

In [None]:
def _log_and_save(split, split_metrics):
    """Print ``split_metrics`` via the trainer and write them to ``<split>_results.json``."""
    trainer.log_metrics(split, split_metrics)
    trainer.save_metrics(split, split_metrics)

# Fine-tune, then persist the final weights (the best checkpoint, since
# load_best_model_at_end=True) and the trainer state.
train_results = trainer.train()
trainer.save_model()
_log_and_save("train", train_results.metrics)
trainer.save_state()

# Score the restored model on the held-out split.
metrics = trainer.evaluate(eval_dataset=prepared_ds["test"])
_log_and_save("eval", metrics)

In [None]:
# Reload the fine-tuned model from disk and classify one held-out image.
# trainer.save_model() wrote the best model (load_best_model_at_end=True) to
# the output directory, so load from there rather than a hard-coded
# `checkpoint-<step>` folder whose step number depends on the run.
model_ = SwinForImageClassification.from_pretrained(training_args.output_dir)
model_.eval()

# Pick a file that is guaranteed to be in this class's test split.
sample_class = "31"
sample_dir = os.path.join(test_folder, sample_class)
sample_path = os.path.join(sample_dir, sorted(os.listdir(sample_dir))[0])
with Image.open(sample_path) as sample_img:  # close the file handle promptly
    sample_inputs = feature_extractor(sample_img.convert('RGB'), return_tensors="pt")

with torch.no_grad():
    pred = model_(**sample_inputs).logits.argmax(-1).item()
# PretrainedConfig normalises id2label keys to int, so index with the int
# prediction; a str key raises KeyError.
print(model_.config.id2label[pred])

In [None]:
# (Removed a stray `!python` with no arguments: it would launch an interactive
# Python interpreter from inside the notebook rather than run anything.)