In [1]:
import torch
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import datasets, transforms

from transformers import ViTImageProcessor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint that is fine-tuned later in the notebook; its image processor
# is loaded here only to read the preprocessing settings it ships with.
model_name = "google/vit-large-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

# Normalisation statistics and the "height" entry of the processor's size
# config, pulled out one by one so the cells below can reuse them.
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

# Normalisation step shared by every transform pipeline in this notebook.
normalize = transforms.Normalize(mean=image_mean, std=image_std)

# train_transforms = transforms.Compose(
#     [
#         transforms.RandomResizedCrop(size),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         normalize,
#     ]
# )
# val_transforms = transforms.Compose(
#     [
#         transforms.Resize(size),
#         transforms.CenterCrop(size),
#         transforms.ToTensor(),
#         normalize,
#     ]
# )
# test_transforms = transforms.Compose(
#     [
#         transforms.Resize(size),
#         transforms.CenterCrop(size),
#         transforms.ToTensor(),
#         normalize,
#     ]
# )

In [14]:
# Single preprocessing pipeline used for every ImageFolder in this notebook.
# The target resolution comes from the processor config (`size`) instead of a
# hard-coded 224, so switching `model_name` to a checkpoint with a different
# input resolution keeps the images and the model in agreement.
# NOTE(review): because this one pipeline is also passed to the validation and
# test datasets, RandomHorizontalFlip is applied at evaluation time as well,
# which makes evaluation metrics slightly non-deterministic. Consider a
# separate flip-free pipeline for the held-out splits.
transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

In [15]:
# Load the mncai images and carve out train / validation / test subsets.
mncai_dataset = datasets.ImageFolder(root="./data/mncai/train", transform=transform)

# Fixed seed for the split: without an explicit generator, random_split draws
# from the global RNG, so every kernel restart produced a different partition
# and images could migrate between train and the held-out sets across runs.
SPLIT_SEED = 42

val_ratio = 0.2
test_ratio = 0.1
train_size = int((1 - val_ratio - test_ratio) * len(mncai_dataset))
val_size = int(val_ratio * len(mncai_dataset))
# The test split absorbs the rounding remainder so the three sizes always sum
# to the dataset length, which random_split requires.
test_size = len(mncai_dataset) - train_size - val_size

mncai_train_dataset, mncai_val_dataset, mncai_test_dataset = random_split(
    mncai_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [16]:
# CIFAKE has its own train/ and test/ folders; only a validation subset has
# to be split off from the training folder.
CIFAKE_train_dataset = datasets.ImageFolder(root="./data/CIFAKE/train", transform=transform)
CIFAKE_test_dataset = datasets.ImageFolder(root="./data/CIFAKE/test", transform=transform)

# Fixed seed for the split: without an explicit generator, random_split draws
# from the global RNG and the train/validation partition changed on every run.
SPLIT_SEED = 42

val_ratio = 0.2
train_size = int((1 - val_ratio) * len(CIFAKE_train_dataset))
# Validation takes whatever is left so the two sizes sum to the dataset length.
val_size = len(CIFAKE_train_dataset) - train_size

# CIFAKE_train_dataset is rebound from the full folder to its training subset.
CIFAKE_train_dataset, CIFAKE_val_dataset = random_split(
    CIFAKE_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [17]:
# Pick the best available accelerator: CUDA first, then Apple's MPS backend,
# and finally the CPU. The original only checked MPS, so a machine with an
# NVIDIA GPU silently fell back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [18]:
def collate_fn(examples):
    """Merge a list of ``(image_tensor, label)`` pairs into one batch dict.

    Args:
        examples: Sequence of 2-tuples; the first item of each is a tensor
            (all of the same shape) and the second is its label.

    Returns:
        dict with two entries:
            ``"pixel_values"`` - the image tensors stacked along a new
            leading dimension;
            ``"labels"`` - the labels gathered into a single tensor.
    """
    # Transpose the list of pairs into a tuple of images and a tuple of labels.
    images, targets = zip(*examples)
    batch = {}
    batch["pixel_values"] = torch.stack(images)
    batch["labels"] = torch.tensor(targets)
    return batch

In [19]:
BATCH_SIZE = 128

# Pool the two sources split by split, so each combined dataset contains
# mncai samples followed by CIFAKE samples.
train_dataset = ConcatDataset([mncai_train_dataset, CIFAKE_train_dataset])
val_dataset = ConcatDataset([mncai_val_dataset, CIFAKE_val_dataset])
test_dataset = ConcatDataset([mncai_test_dataset, CIFAKE_test_dataset])

# Only the training loader is shuffled. Validation and test loaders iterate in
# a fixed order so evaluation is repeatable and batch i always holds the same
# samples; shuffling them gave no benefit and made predictions impossible to
# line up with dataset indices.
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False)

In [20]:
# Class-name <-> class-index lookup tables passed to the model config.
# The reverse table is derived from the forward one so the two cannot drift.
label2id = {'FAKE': 0, 'REAL': 1}
id2label = {index: name for name, index in label2id.items()}

In [21]:
from transformers import ViTForImageClassification

# Settings for the classification head. The checkpoint's classifier has 1000
# outputs (see the load message printed below), so mismatched sizes must be
# tolerated for the 2-class head to be created with fresh weights.
classifier_kwargs = {
    "num_labels": 2,
    "id2label": id2label,
    "label2id": label2id,
    "ignore_mismatched_sizes": True,
}

model = ViTForImageClassification.from_pretrained(model_name, **classifier_kwargs)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([2, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
from transformers import TrainingArguments, Trainer

# Fine-tuning configuration.
# * eval_strategy="steps" is required for eval_steps to take effect; with the
#   default strategy ("no") the validation set was never evaluated.
# * save_total_limit bounds disk usage: checkpointing every 10 steps without a
#   limit keeps every checkpoint of this large model on disk.
# NOTE(review): evaluating and saving every 10 steps is very frequent for a
# dataset of this size; consider raising both values for full runs.
# NOTE(review): `eval_strategy` and `processing_class` need a recent
# transformers release; the deprecation warning emitted for this cell suggests
# the installed version already supports them - confirm before merging.
train_args = TrainingArguments(
    output_dir="output-models",
    eval_strategy="steps",
    eval_steps=10,
    save_steps=10,
    save_total_limit=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
)

# `processing_class` replaces the deprecated `tokenizer` argument; it lets the
# Trainer save the image processor alongside each checkpoint.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    processing_class=processor,
)

trainer.train()

  trainer = Trainer(


Step,Training Loss


KeyboardInterrupt: 