In [1]:
from model import *
from data import *
from vis import *

In [2]:
# Run settings (dataset, model, learning rate, batch size, epochs) are read
# from configs/config.json, resolved relative to the notebook's working dir.
config = Config().load(
    os.path.join("configs", "config.json"),
)

In [3]:
def itertoolsBetter(dataIter):
    """Yield items from ``dataIter`` endlessly, re-iterating it after each pass.

    Unlike ``itertools.cycle`` (which caches the first pass and replays it),
    this starts a fresh ``for`` over ``dataIter`` on every pass, so an iterable
    that reshuffles on each iteration yields a new order every time.

    If a complete pass produces nothing (an empty iterable, or a one-shot
    iterator that is already exhausted), the generator finishes instead of
    spinning forever; the caller's next ``next()`` then raises StopIteration.
    """
    while True:
        producedAny = False
        for batch in dataIter:
            producedAny = True
            yield batch
        if not producedAny:
            # Nothing to cycle over -- without this the outer loop would busy-wait forever.
            return

In [4]:
class MarginMSE(nn.Module):
    """MSE to the target minus a weighted MSE between prediction and input.

        loss = MSE(yPred, yTrue) - beta * MSE(x, yPred)

    The subtracted term rewards predictions that differ from the input ``x``.
    Assuming ``x``, ``yPred`` and ``yTrue`` share a shape, per element the
    loss is minimised at ``yPred = (yTrue - beta * x) / (1 - beta)`` for
    ``0 <= beta < 1``, which overshoots ``yTrue`` away from ``x``. For
    ``beta >= 1`` it is unbounded below.

    ``forward`` takes three tensors, so this is not a drop-in replacement for
    ``nn.MSELoss`` in a two-argument ``objective(outputs, targets)`` call.
    """

    def __init__(self, beta=0.2):
        super(MarginMSE, self).__init__()
        self.mse = nn.MSELoss()
        # Weight of the "move away from the input" term.
        self.beta = beta

    def forward(self, x, yPred, yTrue):
        fitTerm = self.mse(yPred, yTrue)
        departureTerm = self.mse(x, yPred)
        return fitTerm - self.beta * departureTerm

In [5]:
def _saveCheckpoint(model, config):
    """Save ``model`` and ``config`` under checkpoints/latest/ and checkpoints/<timestamp>/.

    The timestamp has minute resolution ("%Y-%m-%d %H-%M"), so two saves in the
    same minute overwrite the same timestamped directory.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H-%M")
    for directory in (os.path.join("checkpoints", "latest"), os.path.join("checkpoints", stamp)):
        # makedirs also creates the parent "checkpoints" dir on a fresh checkout;
        # a bare os.mkdir raises FileNotFoundError there and the model is lost.
        os.makedirs(directory, exist_ok=True)
        torch.save(model, os.path.join(directory, "checkpoint.pt"))
        config.save(os.path.join(directory, "config.json"))


def trainModel(config, modelClass, datasetClass, splitSeed=None):
    """Train ``modelClass`` on ``datasetClass`` with Adam and an MSE objective.

    After every optimiser step, one batch from the held-out 20% split is also
    evaluated. Train/test loss and "diff" (MSE between a batch's inputs and the
    model's outputs) are sent to a ``Client`` at 127.0.0.1:12954.

    Parameters
    ----------
    config
        Reads ``dataset``, ``model``, ``learningRate``, ``batchSize`` and
        ``epochs``; written next to the checkpoint via ``config.save``.
    modelClass, datasetClass
        Instantiated as ``modelClass(config.model)`` and
        ``datasetClass(config.dataset)``. Batches must unpack as
        ``(inputs, targets)``.
    splitSeed : int, optional
        Seed for the 80/20 train/test split. ``None`` (default) leaves the split
        unseeded; pass an int to get the same split on every run.

    Returns
    -------
    The trained model. It is checkpointed both when all epochs finish and when
    training is stopped early with KeyboardInterrupt (interrupting the kernel).
    """
    dataset = datasetClass(config.dataset)
    model = modelClass(config.model)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learningRate)
    # MarginMSE takes (x, yPred, yTrue); it cannot replace this two-argument
    # objective without changing the objective(...) calls below.
    objective = nn.MSELoss()

    splitKwargs = {}
    if splitSeed is not None:
        splitKwargs["generator"] = torch.Generator().manual_seed(splitSeed)
    trainSet, testSet = torch.utils.data.random_split(dataset, [0.8, 0.2], **splitKwargs)
    train = DataLoader(trainSet, batch_size=config.batchSize, shuffle=True)
    test = DataLoader(testSet, batch_size=config.batchSize, shuffle=True)

    # Endless stream of test batches so every train step can be paired with one.
    testIter = itertoolsBetter(test)

    client = Client("127.0.0.1", 12954)
    numBatches = len(train)

    try:
        for epoch in range(config.epochs):
            for progress, (inputs, targets) in enumerate(train, start=1):
                model.train()
                optimizer.zero_grad()

                outputs = model(inputs)
                loss = objective(outputs, targets)
                loss.backward()
                optimizer.step()

                # Metrics only: no autograd graph is needed for any of these.
                with torch.no_grad():
                    trainLoss = loss.item()
                    trainDiff = objective(inputs, outputs).item()

                    model.eval()
                    testInputs, testTargets = next(testIter)
                    testOutputs = model(testInputs)
                    testLoss = objective(testOutputs, testTargets).item()
                    testDiff = objective(testInputs, testOutputs).item()

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)

                client.send("Train Diff", trainDiff)
                client.send("Test Diff", testDiff)

                print(f"\r{epoch + 1} | {progress}/{numBatches} | {(progress / numBatches) * 100:.3f}% |  Train Loss: {trainLoss:.2f} | Test Loss: {testLoss:.2f} | Memory: {torch.cuda.memory_allocated()}", end="")

    except KeyboardInterrupt:
        # Interrupting is the intended way to stop early; fall through and save.
        print("\nTraining interrupted; saving checkpoint.")
    else:
        # Terminate the \r-overwritten progress line.
        print()

    _saveCheckpoint(model, config)
    return model

In [None]:
# Bind the returned model so later cells can use it without relying on `_` / Out[N].
trainedModel = trainModel(config, UNet, FontData)

data\fonts\coure.fon invalid pixel size
data\fonts\dosapp.fon invalid pixel size
data\fonts\modern.fon unknown file format
data\fonts\roman.fon unknown file format
data\fonts\script.fon unknown file format
data\fonts\serife.fon invalid pixel size
data\fonts\smalle.fon invalid pixel size
data\fonts\sserife.fon invalid pixel size
