In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import tqdm
import copy
import torchvision.models as models
import io
import glob

In [None]:
class MyDataManager(Dataset):
    """Dataset that reads image paths from a list file and yields ``(tensor, label)``.

    Each non-blank line of the list file at ``root`` is treated as an image path.
    The integer label is parsed from the image's parent directory name: it is the
    second-to-last ``"_"``-separated field (e.g. ``.../cls_1_x/img.png`` -> ``1``).

    Args:
        root: Path to a text file containing one image path per line.
        transform: Optional callable applied to the (RGB) PIL image before it is
            converted to a tensor with ``transforms.ToTensor()``.
    """

    def __init__(self, root, transform=None):
        super().__init__()
        with open(root, "r") as f:
            # Drop blank lines so a stray empty line does not become an
            # "image path" that only fails later inside a DataLoader worker.
            self.image_list = [line for line in f.read().splitlines() if line.strip()]
        self.transform = transform

    @staticmethod
    def _parse_label(path):
        """Return the integer label encoded in the parent directory name of ``path``.

        Assumes ``/``-separated paths whose parent directory looks like
        ``<a>_<label>_<b>`` (label is the second-to-last ``_`` field).
        """
        parent_dir = path.split("/")[-2]
        return int(parent_dir.split("_")[-2])

    def __getitem__(self, idx):
        path = self.image_list[idx]
        # Use a context manager so the underlying file handle is released
        # promptly, and force 3-channel RGB: ResNet expects 3 input channels and
        # JPEG re-encoding cannot handle modes such as RGBA.
        with Image.open(path) as img:
            img = img.convert("RGB")
        label = self._parse_label(path)
        if self.transform is not None:
            img = self.transform(img)
        return transforms.ToTensor()(img), label

    def __len__(self):
        return len(self.image_list)

In [None]:
class JPEGCompressionTransform:
    """Round-trip a PIL image through in-memory JPEG encoding.

    Used as a data transform to reproduce JPEG compression artifacts at a
    chosen quality level.

    Args:
        quality: JPEG quality passed to ``PIL.Image.Image.save``.
    """

    # Modes Pillow's JPEG encoder writes directly; anything else (e.g. RGBA,
    # P, LA) must be converted first or ``save`` raises ``OSError``.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, quality=55):
        self.quality = quality

    def __call__(self, img):
        if img.mode not in self._JPEG_MODES:
            img = img.convert("RGB")
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=self.quality)
        buffer.seek(0)
        compressed_img = Image.open(buffer)
        # Decode now so the returned image no longer depends on lazy reads
        # from the in-memory buffer.
        compressed_img.load()
        return compressed_img

    def __repr__(self):
        return f"{self.__class__.__name__}(quality={self.quality})"

In [None]:
def make_classifier_head(in_features, dropout=0.5):
    """Build the binary-classification MLP head that replaces ResNet's ``fc``.

    The layer order (and therefore the ``state_dict`` key indices ``0``, ``3``,
    ``6``) matches the original notebook so previously saved checkpoints load.

    Args:
        in_features: Width of the backbone feature vector.
        dropout: Dropout probability used after each hidden layer.

    Returns:
        ``nn.Sequential`` mapping ``(N, in_features)`` to ``(N, 1)`` values in
        ``(0, 1)`` via a final sigmoid.
    """
    return nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )


def net_make(weight, verbose=True):
    """Create a ResNet-50 with a binary sigmoid head.

    Args:
        weight: Forwarded to ``torchvision.models.resnet50`` as ``weights``
            (``None`` for random init, or e.g. ``models.ResNet50_Weights.DEFAULT``).
        verbose: If True (the original behaviour), print the model.

    Returns:
        The model, whose ``fc`` is the head built by ``make_classifier_head``.
    """
    # torchvision's keyword is ``weights`` (plural); ``weight=`` is forwarded
    # to ``ResNet.__init__`` and raises TypeError.
    net = models.resnet50(weights=weight)
    # Read the backbone width instead of hard-coding 2048.
    net.fc = make_classifier_head(net.fc.in_features)
    if verbose:
        print(net)
    return net

In [None]:
# Evaluate the trained binary classifier on the test list, after re-encoding
# every image as JPEG (quality 95), and report accuracy.
test_path = "test.txt"
weight = "ResNet_weight.pth"  # path to the trained state_dict checkpoint
JPEG_QUALITY = 95
BATCH_SIZE = 8
# NOTE(review): with spawn-based multiprocessing (Windows/macOS), workers may fail
# to unpickle classes defined in a notebook; set to 0 if the DataLoader errors.
NUM_WORKERS = 2
THRESHOLD = 0.5  # sigmoid output >= THRESHOLD is predicted as class 1

device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([JPEGCompressionTransform(quality=JPEG_QUALITY)])
test_data = MyDataManager(test_path, transform=transform)
test_dataLoader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    # Pinned memory only speeds up host->GPU copies; on CPU it just warns.
    pin_memory=(device == "cuda"),
)

net = net_make(None).to(device)
# The keyword is ``weights_only`` (plural); a misspelled kwarg is forwarded to
# the unpickler and raises TypeError. ``map_location`` lets a GPU-saved
# checkpoint load on a CPU-only machine.
net.load_state_dict(torch.load(weight, map_location=device, weights_only=True))
net.eval()

total_num = 0
accuracy_test = 0.0  # running COUNT of correct predictions, divided at the end
with torch.no_grad():
    for imgs, labels in tqdm.tqdm(test_dataLoader):
        imgs = imgs.to(device)
        labels = labels.to(device=device, dtype=torch.float32).reshape(-1, 1)
        outputs = net(imgs)  # (N, 1) sigmoid outputs from the head
        preds = (outputs >= THRESHOLD).float()
        accuracy_test += (preds == labels).sum().item()
        total_num += imgs.shape[0]

if total_num == 0:
    print("accuracy(test) = n/a (empty test set)")
else:
    print(f"accuracy(test) = {accuracy_test/total_num*100:.3f}%")