In [18]:
# This Demo is designed to show off the capabilites of the MTGNet models I trained.
# Copy an Images you would like the network to classify into the images folder.
# Images will need to be resized to 434x315 for this demo to function properly.
# You can use the variables below to configure how you would like your images to be classified.
include_dualcolor = True
LEARNING_RATE = .020711
MOMENTUM = .9
WEIGHT_DECAY = .001108
TYPE_MODEL_PATH = "models/dualcolor_type_model"
COLOR_MODEL_PATH = "models/dualcolor_color_model"
IMAGE_PATH = "Images"
COLOR_DICT = {0: 'B', 1: '', 2: 'G', 3: 'W', 4: 'R', 5: 'U', 6: 'RU', 7: 'RW', 8: 'BG', 9: 'GW', 10: 'BR', 11: 'UW', 12: 'BW', 13: 'BU', 14: 'GR', 15: 'GU'}
TYPE_DICT = {0: 'Construct', 1: 'Eldrazi', 2: 'Cat', 3: 'Bird', 4: 'Spirit', 5: 'Zombie', 6: 'Human', 7: 'Goblin', 8: 'Elemental', 9: 'Elf', 10: 'Beast', 11: 'Merfolk', 12: 'Vampire', 13: 'Insect', 14: 'Angel', 15: 'Dragon', 16: 'Giant'}

In [7]:
import torch
import os
import glob
import re
import numpy as np
import matplotlib.pyplot as plt
try:
    # For 2.7
    import cPickle as pickle
except:
    # For 3.x
    import pickle


def restore(net, save_file):
    """Restores the weights from a saved file
    This does more than the simple Pytorch restore. It checks that the names
    of variables match, and if they don't doesn't throw a fit. It is similar
    to how Caffe acts. This is especially useful if you decide to change your
    network architecture but don't want to retrain from scratch.
    Args:
        net(torch.nn.Module): The net to restore
        save_file(str): The file path
    """

    net_state_dict = net.state_dict()
    restore_state_dict = torch.load(save_file)

    restored_var_names = set()

    print('Restoring:')
    for var_name in restore_state_dict.keys():
        if var_name in net_state_dict:
            var_size = net_state_dict[var_name].size()
            restore_size = restore_state_dict[var_name].size()
            if var_size != restore_size:
                print('Shape mismatch for var', var_name, 'expected', var_size, 'got', restore_size)
            else:
                if isinstance(net_state_dict[var_name], torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    net_state_dict[var_name] = restore_state_dict[var_name].data
                try:
                    net_state_dict[var_name].copy_(restore_state_dict[var_name])
                    print(str(var_name) + ' -> \t' + str(var_size) + ' = ' + str(int(np.prod(var_size) * 4 / 10**6)) + 'MB')
                    restored_var_names.add(var_name)
                except Exception as ex:
                    print('While copying the parameter named {}, whose dimensions in the model are'
                          ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                              var_name, var_size, restore_size))
                    raise ex

    ignored_var_names = sorted(list(set(restore_state_dict.keys()) - restored_var_names))
    unset_var_names = sorted(list(set(net_state_dict.keys()) - restored_var_names))
    print('')
    if len(ignored_var_names) == 0:
        print('Restored all variables')
    else:
        print('Did not restore:\n\t' + '\n\t'.join(ignored_var_names))
    if len(unset_var_names) == 0:
        print('No new variables')
    else:
        print('Initialized but did not modify:\n\t' + '\n\t'.join(unset_var_names))

    print('Restored %s' % save_file)


def restore_latest(net, folder):
    """Restores the most recent weights in a folder
    Args:
        net(torch.nn.module): The net to restore
        folder(str): The folder path
    Returns:
        int: Attempts to parse the epoch from the state and returns it if possible. Otherwise returns 0.
    """

    checkpoints = sorted(glob.glob(folder + '/*.pt'), key=os.path.getmtime)
    start_it = 0
    if len(checkpoints) > 0:
        restore(net, checkpoints[-1])
        try:
            start_it = int(re.findall(r'\d+', checkpoints[-1])[-1])
        except:
            pass
    return start_it

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MTGNetDualcolorColor(nn.Module):
    def __init__(self):
        super(MTGNetDualcolorColor, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 7, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, stride=2, padding=1)
        self.conv5_bn = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512, 256)
        self.lin1_dropout = nn.Dropout(.5)
        self.fc2 = nn.Linear(256, 16)
        self.accuracy = None

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.conv3(x)
        x = self.conv3_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.lin1_dropout(x)
        x = self.conv4(x)
        x = self.conv4_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.lin1_dropout(x)
        x = self.conv5(x)
        x = self.conv5_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.lin1_dropout(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

    def loss(self, prediction, label, reduction='mean'):
        loss_val = F.cross_entropy(prediction, label.squeeze(), reduction=reduction)
        return loss_val

    def load_last_model(self, dir_path):
        return restore_latest(self, dir_path)


class MTGNetType(nn.Module):
    def __init__(self):
        super(MTGNetType, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 7, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, stride=2, padding=1)
        self.conv5_bn = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512, 256)
        self.lin1_dropout = nn.Dropout(.5)
        self.fc2 = nn.Linear(256, 17)
        self.accuracy = None

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.conv3(x)
        x = self.conv3_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.lin1_dropout(x)
        x = self.conv4(x)
        x = self.conv4_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.lin1_dropout(x)
        x = self.conv5(x)
        x = self.conv5_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.lin1_dropout(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

    def loss(self, prediction, label, reduction='mean'):
        loss_val = F.cross_entropy(prediction, label.squeeze(), reduction=reduction)
        return loss_val

    def load_last_model(self, dir_path):
        return restore_latest(self, dir_path)


class MTGNetMonocolorColor(nn.Module):
    def __init__(self):
        super(MTGNetMonocolorColor, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 7, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 5, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        self.conv4_bn = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 3, stride=2, padding=1)
        self.conv5_bn = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512, 256)
        self.lin1_dropout = nn.Dropout(.5)
        self.fc2 = nn.Linear(256, 6)
        self.accuracy = None

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.conv3(x)
        x = self.conv3_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.lin1_dropout(x)
        x = self.conv4(x)
        x = self.conv4_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = self.lin1_dropout(x)
        x = self.conv5(x)
        x = self.conv5_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.lin1_dropout(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

    def loss(self, prediction, label, reduction='mean'):
        loss_val = F.cross_entropy(prediction, label.squeeze(), reduction=reduction)
        return loss_val

    def load_last_model(self, dir_path):
        return restore_latest(self, dir_path)

In [28]:
import torch.optim as optim
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
type_model = MTGNetType().to(device)
optimizer = optim.SGD(type_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
start_epoch = type_model.load_last_model(TYPE_MODEL_PATH)
if include_dualcolor:
  color_model = MTGNetDualcolorColor().to(device)
else:
  color_model = MTGNetMonocolorColor().to(device)
optimizer = optim.SGD(color_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
start_epoch = color_model.load_last_model(COLOR_MODEL_PATH)
for im_file in os.listdir(IMAGE_PATH):
  if not os.path.isdir(im_file):
    image = Image.open(IMAGE_PATH + "/" + im_file)
    print("Image: " + im_file)
    image = image.convert('RGB')
    image = np.asarray(image).transpose(-1, 0, 1)
    image = image/255
    image = torch.from_numpy(np.asarray(image))
    image = image.float()
    image = image[None, :, :]
    image.to(device)
    type = int(type_model(image).max(1)[1])
    color = int(color_model(image).max(1)[1])
    print("Color: " + COLOR_DICT[color])
    print("Type: " + TYPE_DICT[type])

Image: 0a8b6f90-c609-4ffd-9136-9fd7a833ebb3.png
Color: GW
Type: Construct
Image: 0a7f0d35-b91e-461a-822c-7ae2798ca51a.png
Color: RU
Type: Insect
Image: 0a8e78b3-3232-4d48-9d6c-540951a0330e.png
Color: GR
Type: Vampire
Image: 0a7b01ce-2729-4b70-ae0a-b9a1007be78f.png
Color: R
Type: Construct
Image: 0a9c4c63-402e-489e-ab0d-1c98309b010a.png
Color: GR
Type: Insect
Image: 0a7c9678-dea7-4219-bac0-9e1cef531f54.png
Color: GU
Type: Eldrazi
