In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import tqdm
import copy
import torchvision.models as models
from icecream import ic

In [None]:
class MyDataManager(Dataset):
    """Image dataset driven by a text file that lists one image path per line.

    The integer label is parsed from the image's parent directory name: it is the
    second-to-last ``_``-separated field of that name
    (e.g. ``data/cls_1_x/img.png`` -> label ``1``).

    Args:
        root: path to the text file containing the image paths.
        transform: optional callable applied to the RGB PIL image. It may return
            either a PIL image or a tensor; a PIL result is converted with
            ``transforms.ToTensor()``.
    """

    def __init__(self, root, transform=None):
        super().__init__()
        with open(root, "r") as f:
            # Skip blank / whitespace-only lines (e.g. stray empty lines in the list
            # file); they would otherwise become unopenable "image paths".
            self.image_list = [line for line in f.read().splitlines() if line.strip()]
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for the idx-th listed image."""
        path = self.image_list[idx]
        # Force 3-channel RGB: the ResNet backbone expects 3 input channels, and
        # grayscale/RGBA/palette files would otherwise yield 1- or 4-channel tensors.
        # The context manager closes the file; convert() has already loaded the pixels.
        with Image.open(path) as src:
            img = src.convert("RGB")
        label = int(path.split("/")[-2].split("_")[-2])
        if self.transform is not None:
            img = self.transform(img)
        # Only convert when the user transform has not produced a tensor already;
        # ToTensor raises a TypeError when given a tensor.
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)
        return img, label

    def __len__(self):
        return len(self.image_list)

In [None]:
train_path = "train.txt"
val_path = "val.txt"
judge_path = "real"  # NOTE(review): not referenced in the visible cells.

# NOTE(review): no resize transform is passed, so batching with the default collate
# assumes every listed image has the same spatial size — confirm.
train_data = MyDataManager(train_path)
train_dataLoader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
val_data = MyDataManager(val_path)
# Evaluation does not benefit from shuffling; a fixed order keeps validation
# passes deterministic and easier to debug.
val_dataLoader = DataLoader(val_data, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
def net_make(weight):
    """Build a ResNet-50 with a binary-classification MLP head.

    ``weight`` is forwarded unchanged to ``models.resnet50(weights=...)``. The
    backbone's ``fc`` layer is replaced by
    Linear(in->1024) -> ReLU -> Dropout(0.5) -> Linear(1024->256) -> ReLU
    -> Dropout(0.5) -> Linear(256->1) -> Sigmoid,
    so the model outputs one probability per image with shape (batch, 1).

    The head is a single ``nn.Sequential``; its linear layers live at indices
    0, 3 and 6, which fixes the ``fc.<idx>.*`` keys stored in saved checkpoints.
    """
    model = models.resnet50(weights=weight)
    in_features = model.fc.in_features  # 2048 for ResNet-50
    head_layers = [
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    ]
    model.fc = nn.Sequential(*head_layers)
    ic(model)
    return model

In [None]:
save_path = "ResNet_weight.pth"  # best-validation state_dict written by the training loop
text_name = "ResNet.txt"  # per-epoch validation accuracy log (opened in append mode)

epochs = 3
seed = 0
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG before building the model so the head initialisation,
# dropout masks and DataLoader shuffling repeat across runs. (Bit-exact GPU
# reproducibility may additionally need deterministic cuDNN settings.)
torch.manual_seed(seed)

net = net_make(models.ResNet50_Weights.IMAGENET1K_V2).to(device)
# BCELoss consumes the probabilities produced by net_make's final Sigmoid.
# NOTE(review): BCEWithLogitsLoss on raw logits is the numerically stabler pairing,
# but switching requires removing the Sigmoid from the head.
criterion = nn.BCELoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0

In [None]:
def _count_correct(outputs, labels):
    """Count predictions matching the 0/1 targets.

    ``outputs`` are the (batch, 1) sigmoid probabilities from net_make's head and
    ``labels`` the (batch, 1) float targets. Probabilities >= 0.5 count as class 1.
    A new prediction tensor is built instead of overwriting ``outputs`` in place,
    so the network output (still referenced by the autograd graph) is left intact.
    """
    preds = (outputs >= 0.5).to(labels.dtype)
    return torch.sum(preds == labels).item()


for epoch in range(epochs):
    running_loss = 0.0
    accuracy_train = 0.0
    total_num = 0.0
    net.train()

    for i, (imgs, labels) in tqdm.tqdm(enumerate(train_dataLoader), total=len(train_dataLoader)):
        imgs = imgs.to(device)
        # BCELoss needs float targets shaped like the (batch, 1) network output.
        labels = labels.to(device=device, dtype=torch.float32).reshape(-1, 1)
        optimizer.zero_grad()
        outputs = net(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.shape[0]
        # loss is a batch mean; weight by batch size for an exact running average.
        running_loss += loss.item() * batch_size
        with torch.no_grad():
            accuracy_train += _count_correct(outputs, labels)
        total_num += batch_size

        if i % 10 == 9:
            ic(f"epoch:{epoch+1}, iter:{i+1}, loss:{running_loss/total_num:.5f}, accuracy(train) = {accuracy_train/total_num*100:.3f}%")

    net.eval()
    total_num = 0.0
    accuracy_val = 0.0
    with torch.no_grad():
        for imgs, labels in val_dataLoader:
            imgs = imgs.to(device)
            labels = labels.to(device=device, dtype=torch.float32).reshape(-1, 1)
            outputs = net(imgs)
            accuracy_val += _count_correct(outputs, labels)
            total_num += labels.shape[0]
    epoch_acc = accuracy_val/total_num
    ic(f"accuracy(valid) = {epoch_acc*100:.3f}%")
    with open(text_name, mode="a") as f:
        f.write(f"accuracy(valid) = {epoch_acc*100:.3f}%\n")

    # Checkpoint only when validation accuracy strictly improves.
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        best_model_wts = copy.deepcopy(net.state_dict())
        torch.save(best_model_wts, save_path)
ic(f"best_accuracy = {best_acc*100:.3f}%")

In [None]:
test_path = "test.txt"
weight = "ResNet_weight.pth"  # checkpoint path; matches save_path in the training config

test_data = MyDataManager(test_path)
# Order does not affect accuracy; shuffle=False keeps evaluation deterministic.
test_dataLoader = DataLoader(test_data, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the bare architecture: torchvision's `weights` argument only accepts a
# ResNet50_Weights member, its name, or None. Passing the checkpoint *file path*
# there raises a KeyError. load_state_dict (strict by default) then restores
# every trained parameter from the file.
net = net_make(None).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
net.load_state_dict(torch.load(weight, map_location=device, weights_only=True))
net.eval()

total_num = 0
accuracy_test = 0.0
with torch.no_grad():
    for imgs, labels in tqdm.tqdm(test_dataLoader):
        imgs = imgs.to(device)
        labels = labels.to(device=device, dtype=torch.float32).reshape(-1, 1)
        outputs = net(imgs)
        # Threshold the sigmoid probabilities into 0/1 predictions without
        # overwriting the model output in place.
        preds = (outputs >= 0.5).to(labels.dtype)
        accuracy_test += torch.sum(preds == labels).item()
        total_num += labels.shape[0]
ic(f"accuracy(test) = {accuracy_test/total_num*100:.3f}%")