In [1]:
import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchinfo import summary
from PIL import Image, ImageFile
# Let Pillow load truncated image files instead of raising an error on decode.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device} device")

In [None]:
class CustomModel(nn.Module):
    """EfficientNetV2-S with its final linear layer resized to ``num_classes``.

    ``forward`` applies a sigmoid to the network output, so every class gets
    an independent score in [0, 1] rather than a softmax distribution.

    Args:
        num_classes: number of output units of the replaced classifier layer.
        use_pretrained: if True, start the backbone from torchvision's
            ImageNet weights; otherwise it is randomly initialised.
    """

    def __init__(self, num_classes, use_pretrained=False):
        super().__init__()
        # ``weights=`` replaces the deprecated ``pretrained=`` flag; the API
        # exists in every torchvision release that ships efficientnet_v2_s.
        # IMAGENET1K_V1 is what ``pretrained=True`` used to select.
        weights = (models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
                   if use_pretrained else None)
        # Keep the attribute name ``res``: it prefixes every key of saved
        # state_dicts, so renaming it would break loading old checkpoints.
        self.res = models.efficientnet_v2_s(weights=weights)
        num_ftrs = self.res.classifier[1].in_features
        self.res.classifier[1] = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return per-class sigmoid scores for the image batch ``x``."""
        return torch.sigmoid(self.res(x))


# Number of model outputs; must equal the length of the class-name list
# used when sorting images.
num_classes = 28
net = CustomModel(num_classes).to(device)

# map_location lets a checkpoint saved from a GPU run load when `device` fell
# back to CPU. torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('data/effnet_s_BCE.pth', map_location=device)
net.load_state_dict(checkpoint["state_dict"])
net.eval();  # inference mode; trailing ';' suppresses the module repr in Jupyter

In [None]:
size = 480  # images are resized to size x size pixels before inference
Transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Folder whose images are sorted in place; edit for your machine.
path = r'D:\photo tta'

# Class names, indexed by model output position (presumably the label order
# used at training time -- TODO confirm). They double as folder names, so they
# are left as-is even where they look misspelled ('Fishl', 'Ninguang').
List = ['Amber', 'Ayaka', 'Barbara', 'Beidou', 'Eula', 'Fishl', 'Ganyu', 'Hu Tao', 
        'Jean', 'Keqing', 'Kokomi', 'Kuki Shinobu', 'Lisa', 'Lumine', 'Mona', 
        'Ninguang', 'Noelle', 'Raiden', 'Rosaria', 'Sara', 'Shenhe', 'Sucrose', 
        'Xiangling', 'Yae', 'Yanfei', 'Yelan', 'Yoimiya', 'Yun Jin']
assert len(List) == num_classes, "label list must have one name per model output"

TOP_K = 3          # at most this many labels are assigned to one image
SCORE_RATIO = 1.5  # keep a label if its score >= best score / SCORE_RATIO


def _select_label_indices(scores, k=TOP_K, ratio=SCORE_RATIO):
    """Pick the class indices to assign to one image.

    Args:
        scores: 1-D tensor of per-class scores for a single image.
        k: number of top-scoring classes considered.
        ratio: a class other than the best one is kept only if its score is
            at least ``best_score / ratio``.

    Returns:
        Sorted list of class indices; always contains the top-scoring class.
    """
    top = torch.topk(scores, min(k, scores.numel()))
    best = top.values[0].item()
    chosen = [
        idx.item()
        for rank, (value, idx) in enumerate(zip(top.values, top.indices))
        if rank == 0 or value.item() >= best / ratio
    ]
    return sorted(chosen)


def _target_dir_name(label_indices, labels):
    """Join the names of ``label_indices`` with single spaces into a folder name."""
    return ' '.join(labels[i] for i in label_indices)


# Move every image in `path` into a sub-folder named after its predicted
# class(es). Files that cannot be processed are reported and left in place.
with torch.no_grad():
    for im in os.listdir(path):
        src = os.path.join(path, im)
        if not os.path.isfile(src):
            continue  # skip sub-folders, e.g. class folders from a previous run
        try:
            # Close the source file before it is moved: renaming a file that
            # is still open fails on Windows.
            with Image.open(src) as img:
                image = img.convert('RGB')
            batch = Transform(image).unsqueeze(0).to(device)
            scores = net(batch)[0].cpu()
            label_indices = _select_label_indices(scores)
            newdir = os.path.join(path, _target_dir_name(label_indices, List))
            os.makedirs(newdir, exist_ok=True)
            os.rename(src, os.path.join(newdir, im))
        except Exception as exc:  # best-effort: report this file and keep going
            print(f"Skipped {im!r}: {exc}")