In [1]:
from utils import *

In [2]:
config = Config().load(os.path.join("configs", "masked.json"))

In [3]:
def itertoolsBetter(dataIter):
    while True:
        for batch in dataIter:
            yield batch

In [4]:
def trainModel(config, modelClass, datasetClass):
    dataset = datasetClass(config.dataset)
    model = modelClass(config.model)

    print(f"Model has {sum([p.numel() for p in model.parameters()])} parameters")

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learningRate)

    auxiliary = nn.CrossEntropyLoss()
    objective = nn.MSELoss()

    train, test = torch.utils.data.random_split(dataset, [0.8, 0.2], generator=torch.Generator(device=device))
    train = DataLoader(train, batch_size=config.batchSize, shuffle=True, collate_fn=dataset.collate, generator=torch.Generator(device=device))
    test = DataLoader(test, batch_size=config.batchSize, shuffle=True, collate_fn=dataset.collate, generator=torch.Generator(device=device))

    testIter = itertoolsBetter(test)

    client = Client("127.0.0.1", 12954)

    try:
        for epoch in range(config.epochs):
            progress = 0
            for inputs, targets, info in train:
                model.train()
                optimizer.zero_grad()

                outputs, classification = model(inputs)
                loss = objective(outputs, targets) + auxiliary(classification, info) * 0.2

                trainLoss = loss.detach().item()
                trainDiff = objective(inputs, outputs).detach().item()
                trainClass = auxiliary(classification, info)

                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    model.eval()
                    inputs1, targets1, info1 = next(testIter)
                    outputs1, classification1 = model(inputs1)
                    loss1 = objective(outputs1, targets1)

                    testLoss = loss1.detach().item()
                    testDiff = objective(inputs1, outputs1).detach().item()
                    testClass = auxiliary(classification1, info1)

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)

                client.send("Train Diff", trainDiff)
                client.send("Test Diff", testDiff)

                client.send("Train Class", trainClass)
                client.send("Test Class", testClass)

                progress += 1

                print(f"\r{epoch + 1} | {progress}/{len(train)} | {(progress / len(train)) * 100:.3f}% |  Train Loss: {trainLoss:.2f} | Test Loss: {testLoss:.2f}", end="")

    except KeyboardInterrupt:
        latestPath = os.path.join("checkpoints", "pretrain", "latest")
        if not os.path.exists(os.path.join("checkpoints", "pretrain", "latest")):
            os.mkdir(latestPath)

        stamp = datetime.now().strftime("%Y-%m-%d %H-%M")
        timePath = os.path.join("checkpoints", "pretrain", stamp)
        if not os.path.exists(timePath):
            os.mkdir(timePath)

        torch.save(model.state_dict(), os.path.join(latestPath, "checkpoint.pt"))
        torch.save(model.state_dict(), os.path.join(timePath, "checkpoint.pt"))
        config.save(os.path.join(latestPath, "config.json"))
        config.save(os.path.join(timePath, "config.json"))
        return model

In [5]:
trainModel(config, UNet, FontData)

Fonts serialized: 1821/3794google\fonts\ofl\kumarone\KumarOne-Regular.ttf execution context too long
Fonts serialized: 2335/3794google\fonts\ofl\notocoloremojicompattest\NotoColorEmojiCompatTest-Regular.ttf invalid pixel size
Fonts serialized: 3702/3794google\fonts\ofl\zcoolxiaowei\ZCOOLXiaoWei-Regular.ttf [Errno 22] Invalid argument: 'google\\bitmaps\\????? ?? al.bmp'
Images loaded: 266201/266219
Model has 9510028 parameters
3 | 208/978 | 21.268% |  Train Loss: 0.12 | Test Loss: 0.000

UNet(
  (input): Sequential(
    (0): Linear(in_features=1, out_features=24, bias=True)
    (1): ReLU()
  )
  (ups): ModuleList(
    (0): Up(
      (conv): ConvBlock(
        (module): Sequential(
          (0): Conv2d(96, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
          (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
          (4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU()
          (6): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
          (7): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (8): ReLU()
          (9): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=same)
          (10): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (11): ReLU()
        )
      )