In [1]:
from huggingface_hub import notebook_login
import wandb

# Authenticate against the Hugging Face Hub; training later runs with
# push_to_hub = True, so a valid Hub token is needed.
notebook_login()
# Authenticate against Weights & Biases; training later reports to 'wandb'.
wandb.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkimitoinf[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import os
import random
import numpy as np
import torch

def set_seed(seed_value = 42):
    """Seed every random number generator this notebook uses.

    Seeds NumPy, PyTorch (CPU plus current and all CUDA devices) and Python's
    ``random`` module, exports ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed_value: Integer seed shared by all generators (default 42).
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)

    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    # this presumably only affects child processes -- confirm if hash
    # randomization matters here.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed()

In [None]:
from datasets import load_dataset

dataset_path = './disease_data'
# One seed for the shuffle and both splits, so that the train/validation/test
# membership is identical on every run. Without an explicit seed the splits
# are not guaranteed to be reproducible, and re-running this cell could move
# training images into the test split.
SPLIT_SEED = 42

# Print the number of entries found in each sub-directory of the data root.
# Presumably each sub-directory is one class label -- verify against the data.
# Non-directory entries are skipped so a stray file does not crash the listing.
for label in os.listdir(dataset_path):
    label_dir = os.path.join(dataset_path, label)
    if os.path.isdir(label_dir):
        print(label + ': ' + str(len(os.listdir(label_dir))))

# 70% train; the remaining 30% is halved into validation (15%) and test (15%).
dataset = (
    load_dataset(dataset_path, split = 'train')
    .shuffle(seed = SPLIT_SEED)
    .train_test_split(test_size = 0.3, seed = SPLIT_SEED)
)
split = dataset['test'].train_test_split(test_size = 0.5, seed = SPLIT_SEED)
dataset['validation'] = split['train']
dataset['test'] = split['test']
print(dataset)

In [None]:
# Class names in index order, taken from the 'label' feature of the train split.
labels = dataset['train'].features['label'].names

# Both maps store the class index as a string: name -> "i" and "i" -> name.
label2id = {name: str(index) for index, name in enumerate(labels)}
id2label = {str(index): name for index, name in enumerate(labels)}

In [None]:
from transformers import AutoImageProcessor, DefaultDataCollator
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, RandomVerticalFlip, RandomHorizontalFlip

# Pretrained checkpoint shared by the image processor and, later, the model.
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

# Normalize with the mean/std the image processor reports.
normalize = Normalize(mean = image_processor.image_mean, std = image_processor.image_std)

# Target size: a single int when the processor specifies "shortest_edge",
# otherwise an explicit (height, width) pair.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])

# Augmentation pipeline: random crop + random flips, then tensor conversion
# and normalization.
_transforms = Compose([
    RandomResizedCrop(size),
    RandomVerticalFlip(),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

def transforms(examples):
    """Turn a batch of images into model inputs, in place.

    Every entry of ``examples["image"]`` is converted to RGB and passed through
    ``_transforms``; the results are stored under ``"pixel_values"`` and the
    ``"image"`` key is removed.

    Args:
        examples: Batch mapping with an ``"image"`` list whose items expose
            ``convert``.

    Returns:
        The same ``examples`` mapping, modified in place.
    """
    pixel_values = []
    for image in examples["image"]:
        rgb_image = image.convert("RGB")
        pixel_values.append(_transforms(rgb_image))
    examples["pixel_values"] = pixel_values
    del examples["image"]
    return examples

from datasets import DatasetDict
from torchvision.transforms import CenterCrop, Resize

# Deterministic preprocessing for the evaluation splits. The random crop and
# flips in `_transforms` are training-time augmentation; applying them to
# validation/test data makes the reported metrics noisy and not repeatable.
_eval_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])

def eval_transforms(examples):
    """Like `transforms`, but without random augmentation.

    Converts each image in ``examples["image"]`` to RGB, applies
    ``_eval_transforms``, stores the results under ``"pixel_values"`` and
    removes the ``"image"`` key. Returns the same mapping.
    """
    examples["pixel_values"] = [_eval_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

# Augment only the training split; every other split gets the deterministic
# pipeline.
dataset = DatasetDict({
    split_name: split_data.with_transform(transforms if split_name == 'train' else eval_transforms)
    for split_name, split_data in dataset.items()
})
data_collator = DefaultDataCollator()

print(dataset['train'][0])

In [None]:
import numpy as np
import torch
import evaluate

# Load the metric modules once instead of on every evaluation pass; loading
# them inside `compute_metrics` repeats the same lookup each time it is called.
_accuracy_metric = evaluate.load('accuracy')
_f1_precision_recall_metric = evaluate.combine(['f1', 'precision', 'recall'])

def compute_metrics(eval_pred):
    """Compute classification metrics and log ROC / PR curves to wandb.

    Args:
        eval_pred: A ``(logits, labels)`` pair; the predicted class is the
            argmax of ``logits`` over the last axis.

    Returns:
        dict: Merged results of the accuracy metric and the combined
        f1/precision/recall metrics (the latter computed with
        ``average = 'weighted'``).
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis = -1)

    metrics = dict()
    metrics.update(_accuracy_metric.compute(predictions = predictions, references = labels))
    metrics.update(
        _f1_precision_recall_metric.compute(predictions = predictions, references = labels, average = 'weighted')
    )

    # The curves are a side effect for the wandb dashboard. Skip them when no
    # run is active so the function can also be called outside of training.
    if wandb.run is not None:
        probs = torch.softmax(torch.tensor(logits), dim = -1).numpy()
        classes = list(id2label.values())
        true_labels = np.array(labels)
        wandb.log({
            'roc': wandb.plot.roc_curve(y_true = true_labels, y_probas = probs, labels = classes),
            'pr': wandb.plot.pr_curve(y_true = true_labels, y_probas = probs, labels = classes)
        })
        # Confusion matrix is intentionally not logged here (the original note says the graph overlaps):
        # wandb.sklearn.plot_confusion_matrix(y_true = true_labels, y_pred = np.array(predictions), labels = classes)
    return metrics

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from PIL import ImageFile
import os
import torch

# Let PIL load image files that are truncated.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Process-wide settings: enumerate GPUs by PCI bus id, expose only device 7,
# and name the wandb project.
# NOTE(review): the GPU index is machine-specific -- confirm before running elsewhere.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '7',
    'WANDB_PROJECT': 'dedc',
})
torch.cuda.empty_cache()

# Image-classification model built from the pretrained checkpoint, with a head
# sized to this dataset's labels and the string-keyed label maps attached.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint, num_labels = len(labels), id2label = id2label, label2id = label2id
)

def train():
    """Fine-tune the module-level ``model`` and close the wandb run.

    Trains on ``dataset['train']`` and evaluates on ``dataset['validation']``
    once per epoch with ``compute_metrics``. A checkpoint is saved every epoch,
    and the one with the best ``accuracy`` is loaded back into the model at the
    end. Results are reported to wandb and pushed to the Hugging Face Hub.

    The wandb run is finished even when training raises, so a failed or
    interrupted run is not left open.
    """
    training_args = TrainingArguments(
        output_dir = "./dedc",
        # Keep the 'image' column so the dataset transform can read it.
        remove_unused_columns = False,
        eval_strategy = "epoch",
        save_strategy = "epoch",
        learning_rate = 5e-5,
        # 16 samples x 4 accumulation steps = 64 samples per optimizer step per device.
        per_device_train_batch_size = 16,
        gradient_accumulation_steps = 4,
        per_device_eval_batch_size = 16,
        num_train_epochs = 25,
        warmup_ratio = 0.1,
        logging_steps = 10,
        load_best_model_at_end = True,
        metric_for_best_model = "accuracy",
        push_to_hub = True,
        run_name = 'run',
        report_to = 'wandb'
    )

    trainer = Trainer(
        model = model,
        args = training_args,
        data_collator = data_collator,
        train_dataset = dataset['train'],
        eval_dataset = dataset['validation'],
        # NOTE(review): newer transformers releases presumably prefer
        # `processing_class` over `tokenizer` -- confirm against the installed version.
        tokenizer = image_processor,
        compute_metrics = compute_metrics,
    )

    try:
        trainer.train()
    finally:
        # Always close the run, including on errors or a manual interrupt.
        wandb.finish()

# Run the fine-tuning, then write the image processor and the trained model
# into one directory so they can be loaded together later.
train()
model_path = './model'
for artifact in (image_processor, model):
    artifact.save_pretrained(model_path)

In [None]:
from transformers import pipeline
from torchvision.transforms.functional import to_pil_image
from tqdm import notebook
import wandb

# Separate wandb run for the held-out test evaluation.
wandb.init(project = 'dedc', name = 'test')

model_path = './model'
classifier = pipeline("image-classification", model = model_path)

# Evaluate on the original decoded images. The pipeline runs the saved image
# processor itself, so it must not be given `pixel_values`: those tensors are
# already augmented and normalized, and turning them back into PIL images
# corrupts the pixel data before it is preprocessed a second time.
raw_test = dataset['test'].with_format(None)

true_labels = []
pred_labels = []
for example in notebook.tqdm(raw_test):
    prediction = classifier(example['image'].convert('RGB'))
    true_labels.append(example['label'])
    # prediction[0] is used as the predicted class; its name is mapped back to
    # the integer id.
    pred_labels.append(int(label2id[prediction[0]['label']]))

accuracy = evaluate.load('accuracy').compute(predictions = pred_labels, references = true_labels)
f1_precision_recall = evaluate.combine(['f1', 'precision', 'recall']).compute(predictions = pred_labels, references = true_labels, average = 'weighted')
metrics = dict()
for result in [accuracy, f1_precision_recall]:
    metrics.update(result)
wandb.log(metrics)
wandb.sklearn.plot_confusion_matrix(y_true = np.array(true_labels), y_pred = np.array(pred_labels), labels = list(id2label.values()))
wandb.finish()