In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import v2
from tqdm import tqdm

In [2]:
BATCH_SIZE = 16
EPOCHS = 20

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
class CustomDataset(Dataset):
	def __init__(self, root_dir, transform=None, train=True, split_ratio=0.9):
		self.root_dir = root_dir
		self.transform = transform
		self.train = train
		self.images = []
		self.labels = []
		self.classes = ['brightness', 'invert']

		for idx, cls in enumerate(self.classes):
			cls_folder = os.path.join(root_dir, cls)
			for img_name in os.listdir(cls_folder):
				self.images.append(os.path.join(cls_folder, img_name))
				self.labels.append(idx)

		dataset_size = len(self.images)
		split = int(np.floor(split_ratio * dataset_size))
		indices = list(range(dataset_size))
		np.random.shuffle(indices)
		if self.train:
			self.images = [self.images[i] for i in indices[:split]]
			self.labels = [self.labels[i] for i in indices[:split]]
		else:
			self.images = [self.images[i] for i in indices[split:]]
			self.labels = [self.labels[i] for i in indices[split:]]

	def __len__(self):
		return len(self.images)

	def __getitem__(self, idx):
		image = Image.open(self.images[idx]).convert('RGB')
		if self.transform: image = self.transform(image)
		label = self.labels[idx]
		return image, label

In [5]:
train_transform = v2.Compose([
	v2.ToImage(),
	v2.Resize(160),
	v2.RandomCrop(128),
	v2.RandomHorizontalFlip(p=0.5),
	v2.RandomVerticalFlip(p=0.5),
	v2.ToDtype(torch.float32, scale=True),
])

test_transform = v2.Compose([
	v2.ToImage(),
	v2.Resize(128),
	v2.CenterCrop(128),
	v2.ToDtype(torch.float32, scale=True),
])

train_dataset = CustomDataset(root_dir='dataset/', transform=train_transform, train=True)
test_dataset = CustomDataset(root_dir='dataset/', transform=test_transform, train=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
model = timm.create_model('efficientnet_b0', pretrained=True)
model.classifier = nn.Sequential(
	nn.Linear(model.classifier.in_features, 1),
	nn.Sigmoid()
)
model = model.to(device)

In [7]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [8]:
def correct_count(outputs, labels):
	preds = (outputs > 0.5).float()
	return (preds == labels).float().sum()

In [9]:
for epoch in range(EPOCHS):
	model.train()
	train_loss = 0.0
	train_correct = 0
	total_items = 0
	tqdm_bar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{EPOCHS}', total=len(train_loader))

	for inputs, labels in tqdm_bar:
		inputs, labels = inputs.to(device), labels.to(device).float().view(-1, 1)
		optimizer.zero_grad()

		outputs = model(inputs)
		loss = criterion(outputs, labels)
		loss.backward()
		optimizer.step()

		train_loss += loss.item() * inputs.size(0)
		train_correct += correct_count(outputs, labels)
		total_items += inputs.size(0)

		avg_loss = train_loss / total_items
		tqdm_bar.set_postfix(loss=avg_loss)

	lr_scheduler.step()

	model.eval()
	val_loss = 0.0
	val_correct = 0
	with torch.no_grad():
		for inputs, labels in test_loader:
			inputs, labels = inputs.to(device), labels.to(device).float().view(-1, 1)
			outputs = model(inputs)
			loss = criterion(outputs, labels)
			val_loss += loss.item() * inputs.size(0)
			val_correct += correct_count(outputs, labels)

	train_loss /= total_items
	train_accuracy = train_correct / total_items
	val_loss /= len(test_loader.dataset)
	val_accuracy = val_correct / len(test_loader.dataset)

	print(f'Epoch {epoch+1}/{EPOCHS}. '
				f'Train Loss: {train_loss:.3f}, Train Acc: {train_accuracy:.2f}. '
				f'Val Loss: {val_loss:.3f}, Val Acc: {val_accuracy:.2f}')

Training Epoch 1/20: 100%|██████████| 62/62 [00:42<00:00,  1.45it/s, loss=0.273]


Epoch 1/20. Train Loss: 0.273, Train Acc: 0.88. Val Loss: 0.141, Val Acc: 0.97


Training Epoch 2/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.135]


Epoch 2/20. Train Loss: 0.135, Train Acc: 0.95. Val Loss: 0.139, Val Acc: 0.96


Training Epoch 3/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.12]  


Epoch 3/20. Train Loss: 0.120, Train Acc: 0.96. Val Loss: 0.107, Val Acc: 0.97


Training Epoch 4/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0973]


Epoch 4/20. Train Loss: 0.097, Train Acc: 0.97. Val Loss: 0.067, Val Acc: 0.99


Training Epoch 5/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0763]


Epoch 5/20. Train Loss: 0.076, Train Acc: 0.98. Val Loss: 0.075, Val Acc: 0.97


Training Epoch 6/20: 100%|██████████| 62/62 [00:42<00:00,  1.47it/s, loss=0.0734]


Epoch 6/20. Train Loss: 0.073, Train Acc: 0.98. Val Loss: 0.060, Val Acc: 0.98


Training Epoch 7/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0424]


Epoch 7/20. Train Loss: 0.042, Train Acc: 0.99. Val Loss: 0.043, Val Acc: 0.99


Training Epoch 8/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0317]


Epoch 8/20. Train Loss: 0.032, Train Acc: 0.99. Val Loss: 0.047, Val Acc: 0.99


Training Epoch 9/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0335]


Epoch 9/20. Train Loss: 0.033, Train Acc: 0.99. Val Loss: 0.023, Val Acc: 1.00


Training Epoch 10/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0368]


Epoch 10/20. Train Loss: 0.037, Train Acc: 0.99. Val Loss: 0.024, Val Acc: 0.99


Training Epoch 11/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0354]


Epoch 11/20. Train Loss: 0.035, Train Acc: 0.99. Val Loss: 0.016, Val Acc: 1.00


Training Epoch 12/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0164]


Epoch 12/20. Train Loss: 0.016, Train Acc: 0.99. Val Loss: 0.009, Val Acc: 1.00


Training Epoch 13/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0124]


Epoch 13/20. Train Loss: 0.012, Train Acc: 1.00. Val Loss: 0.009, Val Acc: 1.00


Training Epoch 14/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.017]  


Epoch 14/20. Train Loss: 0.017, Train Acc: 0.99. Val Loss: 0.013, Val Acc: 0.99


Training Epoch 15/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0107] 


Epoch 15/20. Train Loss: 0.011, Train Acc: 1.00. Val Loss: 0.011, Val Acc: 1.00


Training Epoch 16/20: 100%|██████████| 62/62 [00:41<00:00,  1.49it/s, loss=0.0129] 


Epoch 16/20. Train Loss: 0.013, Train Acc: 1.00. Val Loss: 0.017, Val Acc: 0.99


Training Epoch 17/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.00694]


Epoch 17/20. Train Loss: 0.007, Train Acc: 1.00. Val Loss: 0.013, Val Acc: 0.99


Training Epoch 18/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.00626]


Epoch 18/20. Train Loss: 0.006, Train Acc: 1.00. Val Loss: 0.012, Val Acc: 0.99


Training Epoch 19/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.00699]


Epoch 19/20. Train Loss: 0.007, Train Acc: 1.00. Val Loss: 0.012, Val Acc: 0.99


Training Epoch 20/20: 100%|██████████| 62/62 [00:41<00:00,  1.48it/s, loss=0.0125] 


Epoch 20/20. Train Loss: 0.012, Train Acc: 1.00. Val Loss: 0.009, Val Acc: 1.00


In [10]:
torch.save(model.state_dict(), 'model.pth')