In [38]:
import os
from PIL import Image

paths = ["../data/img/vit/train/0_real/", "../data/img/vit/train/1_fake/", "../data/img/vit/val/0_real/", "../data/img/vit/val/1_fake/"]

problematic = []

for path in paths:
    for filename in os.listdir(path):
        if filename.endswith(".jpg"):
            img = Image.open(path + filename)
            if img.mode != "RGB":
                problematic.append(path + filename)
                print(path + filename)
                img = img.convert("RGB")
                img.save(path + filename)
            
            
        else:
            continue

In [39]:
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="../data/img/vit")
dataset

Resolving data files: 100%|██████████| 6910/6910 [00:00<00:00, 125127.43it/s]
Resolving data files: 100%|██████████| 1728/1728 [00:00<00:00, 850376.31it/s]
Downloading data files: 100%|██████████| 6910/6910 [00:00<00:00, 178046.83it/s]
Downloading data files: 0it [00:00, ?it/s]
Extracting data files: 0it [00:00, ?it/s]
Downloading data files: 100%|██████████| 1728/1728 [00:00<00:00, 164430.27it/s]
Downloading data files: 0it [00:00, ?it/s]
Extracting data files: 0it [00:00, ?it/s]
Generating train split: 6910 examples [00:00, 41070.81 examples/s]
Generating validation split: 1728 examples [00:00, 40363.31 examples/s]


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 6910
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1728
    })
})

In [41]:
# from PIL import Image
# import numpy as np

# # Assuming you're accessing the first image in the training set
# for i in range(len(dataset['train'])):
#   sample_image = dataset['train'][i]['image']

#   # If the image is a PIL Image
#   if isinstance(sample_image, Image.Image):
#       print(f"Image mode: {sample_image.mode}")
#       if sample_image.mode == 'RGB':
#           print("The image has 3 channels (RGB).")
#       elif sample_image.mode == 'L':
#           print(i + 1)
#           print("The image has 1 channel (grayscale).")
#           break

#   # If the image is a NumPy array
#   elif isinstance(sample_image, np.ndarray):
#       print(f"Image shape: {sample_image.shape}")
#       if len(sample_image.shape) == 3 and sample_image.shape[2] == 3:
#           print("The image has 3 channels (RGB).")
#           break
#       elif len(sample_image.shape) == 2:
#           print("The image has 1 channel (grayscale).")
#           break


In [42]:
from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [43]:
def transform(batch):
  # Take a list of PIL images and turn them to pixel values
  inputs = processor([x for x in batch['image']], return_tensors="pt")

  # Include labels
  inputs['labels'] = batch['label']
  return inputs

In [44]:
train = dataset["train"].with_transform(transform)
validation = dataset["validation"].with_transform(transform)

In [46]:
import torch

def collate_fn(batch):
  return {
    'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
    'labels': torch.tensor([x['labels'] for x in batch])
  }

In [54]:
import numpy

In [47]:
import evaluate
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
  logits, labels = eval_pred
  predictions = np.argmax(logits, axis=-1)
  return metric.compute(predictions=predictions, references=labels)

In [48]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(model_name_or_path, num_labels=2)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
from transformers import TrainingArguments

epochs = 2
warmup_steps = 100
weight_decay = 0.01

training_args = TrainingArguments(
  output_dir='./results',
  num_train_epochs=epochs,
  per_device_train_batch_size=32,
  per_device_eval_batch_size=16,
  evaluation_strategy="steps",
  warmup_steps=warmup_steps,
  weight_decay=weight_decay,
  logging_dir='./logs',
  remove_unused_columns=False
)

In [52]:
from transformers import Trainer

trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=train,
  eval_dataset=validation,
  tokenizer=processor,
  data_collator=collate_fn,
  compute_metrics=compute_metrics
)

In [None]:
trainer.save_model("./model_vit")


In [53]:
trainer.train()

100%|██████████| 432/432 [11:13<00:00,  1.56s/it]

{'train_runtime': 673.7471, 'train_samples_per_second': 20.512, 'train_steps_per_second': 0.641, 'train_loss': 0.1771984100341797, 'epoch': 2.0}





TrainOutput(global_step=432, training_loss=0.1771984100341797, metrics={'train_runtime': 673.7471, 'train_samples_per_second': 20.512, 'train_steps_per_second': 0.641, 'train_loss': 0.1771984100341797, 'epoch': 2.0})

In [57]:
trainer.save_model("./model_transformers")

In [58]:
model = ViTForImageClassification.from_pretrained("./model_transformers")

In [89]:
i = 103
input = torch.tensor(processor(dataset['validation'][i]['image'], return_tensors="pt")['pixel_values'])

outputs = model(input)
predictions = outputs.logits.argmax(dim=-1)
predictions

  input = torch.tensor(processor(dataset['validation'][i]['image'], return_tensors="pt")['pixel_values'])


tensor([0])

tensor([0])