In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image

# Seed torch's global RNG (all devices) so the DataLoader shuffle order and the
# weight initialisation in later cells repeat on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_dir = 'dataset'
IMAGE_SIZE = 128  # SimpleCNN's fc1 (64 * 32 * 32 features) assumes 128x128 inputs
BATCH_SIZE = 32

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # ToTensor yields [0, 1]; (x - 0.5) / 0.5 maps each channel to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# ImageFolder expects one sub-directory per class: dataset/<class_name>/<images>.
train_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [5]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for square RGB images.

    Layout:
        conv3x3(3->32) -> ReLU -> maxpool2
        conv3x3(32->64) -> ReLU -> maxpool2
        flatten -> fc(512) -> ReLU -> fc(num_classes)

    Args:
        num_classes: number of output logits.
        input_size: side length in pixels of the square input images. The 3x3
            convolutions (padding=1) keep the spatial size. Each 2x2 max-pool
            halves it, so fc1 expects 64 * (input_size // 4) ** 2 features.
            Defaults to 128, matching a Resize((128, 128)) preprocessing step.
    """

    def __init__(self, num_classes, input_size=128):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        # One parameter-free pooling module, reused after both conv stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        feature_side = input_size // 4  # two floor-halvings == floor(input_size / 4)
        self.fc1 = nn.Linear(in_features=64 * feature_side * feature_side, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for a batch of images."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension. With the old
        # view(-1, 64 * 32 * 32), a wrongly sized input was silently reshaped into a
        # different batch size; now it fails loudly in fc1.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
        
# One output logit per class sub-folder discovered by ImageFolder.
num_classes = len(train_dataset.classes)
model = SimpleCNN(num_classes)
model = model.to(device)

# Cross-entropy over raw logits, optimised with Adam at a 1e-3 learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train(model, loader, criterion, optimizer, device):
    """Run one training epoch and report loss and accuracy on the training data.

    Args:
        model: classifier returning logits, assumed shaped (batch, num_classes).
        loader: iterable of (images, labels) batches, e.g. a DataLoader.
        criterion: loss function. Assumed to return the batch *mean*, which is the
            default reduction for nn.CrossEntropyLoss.
        optimizer: optimizer over the model's parameters.
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc): the per-sample mean loss and the accuracy in
        percent. Both are also printed.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion gives the batch mean. Weight it by batch size so a short
        # final batch does not count as much as a full one.
        loss_sum += loss.item() * batch_size
        total += batch_size
        # Accuracy uses the logits computed before this step's weight update.
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("train(): loader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total
    print(f'Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.2f}%')
    return epoch_loss, epoch_acc

In [8]:
def predict_image(image_path, model, transform, device, class_names=None):
    """Classify a single image file and return its class label.

    The image is converted to RGB first so it has the 3 channels that the
    3-valued Normalize and conv1 expect. torchvision's default ImageFolder
    loader does the same during training. Without the conversion, grayscale,
    palette or RGBA files produce the wrong channel count.

    Args:
        image_path: path to the image file.
        model: trained classifier. This call switches it to eval mode.
        transform: the same preprocessing pipeline used for training.
        device: device to run inference on.
        class_names: sequence mapping output index -> label. Defaults to the
            notebook-global train_dataset.classes, matching the old behaviour.

    Returns:
        The predicted class label.
    """
    if class_names is None:
        class_names = train_dataset.classes

    # The context manager closes the file handle once convert() has loaded the pixels.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    batch = transform(image).unsqueeze(0).to(device)  # add a batch dimension of 1
    model.eval()
    with torch.no_grad():
        outputs = model(batch)
    predicted_index = outputs.argmax(dim=1).item()
    return class_names[predicted_index]


In [9]:
num_epochs = 10
checkpoint_path = 'genshin_detector_cnn_model.pth'

# One full pass over train_loader per epoch; train() prints its own metrics line.
for epoch in range(num_epochs):
    epoch_banner = f'Epoch {epoch+1}/{num_epochs}'
    print(epoch_banner)
    train(model, train_loader, criterion, optimizer, device)

# NOTE(review): only training-set metrics are reported. There is no held-out split
# in this notebook, so training accuracy alone says little about generalisation.
weights = model.state_dict()
torch.save(weights, checkpoint_path)

print("Training complete and model saved.")

Epoch 1/10
Training Loss: 0.0164, Training Accuracy: 100.00%
Epoch 2/10
Training Loss: 0.0146, Training Accuracy: 99.80%
Epoch 3/10
Training Loss: 0.0096, Training Accuracy: 100.00%
Epoch 4/10
Training Loss: 0.0049, Training Accuracy: 100.00%
Epoch 5/10
Training Loss: 0.0028, Training Accuracy: 100.00%
Epoch 6/10
Training Loss: 0.0022, Training Accuracy: 100.00%
Epoch 7/10
Training Loss: 0.0019, Training Accuracy: 100.00%
Epoch 8/10
Training Loss: 0.0016, Training Accuracy: 100.00%
Epoch 9/10
Training Loss: 0.0013, Training Accuracy: 100.00%
Epoch 10/10
Training Loss: 0.0012, Training Accuracy: 100.00%
Training complete and model saved.


In [11]:
# Classify a few standalone test images. The model can only answer with one of
# train_dataset.classes (fc2 has num_classes outputs), so images of characters it
# was never trained on are forced onto whichever known class they most resemble.
test_image_paths = [
    'Hu_Tao_Test.jpg',
    'Kokomi_Test.jpg',
    # Interesting how it thinks Ryu (red, black, and shades of brown) is Hu Tao.
    'Ryu_Test.jpg',
    # And X is Kokomi (blue, Da ba dee da ba die)
    'X_Test.jpg',
]

for image_path in test_image_paths:
    predicted_class = predict_image(image_path, model, transform, device)
    print(f'The predicted class for the image is: {predicted_class}')

The predicted class for the image is: Hu Tao
The predicted class for the image is: Kokomi
The predicted class for the image is: Hu Tao
The predicted class for the image is: Kokomi
