In [1]:
from utils import *

In [2]:
def itertoolsBetter(dataIter):
    """Yield the items of ``dataIter`` over and over, restarting it when it runs out.

    Nothing is cached: every pass iterates ``dataIter`` afresh, so a re-iterable
    source is re-evaluated on each pass.

    Args:
        dataIter: Any iterable.

    Yields:
        The items of ``dataIter``, pass after pass.

    The generator finishes as soon as a complete pass produces no item (empty
    input, or a one-shot iterator that has been used up). Without that guard
    the outer loop would spin forever without ever yielding.
    """
    while True:
        yieldedAny = False
        for batch in dataIter:
            yieldedAny = True
            yield batch
        if not yieldedAny:
            # Nothing left to replay: stop instead of busy-looping.
            return

In [3]:
class ContrastiveLoss(nn.Module):
    """Temperature-scaled contrastive loss between two batches of embeddings.

    Row ``i`` of ``y1`` is treated as the positive match of row ``i`` of ``y2``;
    every other row of ``y2`` acts as a negative for it.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            cross-entropy; smaller values sharpen the distribution.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.entropy = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, y1, y2, names):
        """Return the cross-entropy of the similarity matrix against the diagonal.

        Args:
            y1: First batch of embeddings; expected to be 2-D because ``.t()``
                is applied to its counterpart.
            y2: Second batch of embeddings, row-aligned with ``y1``.
            names: Unused; accepted so all losses share one call signature.

        Returns:
            Scalar loss tensor.
        """
        y1, y2 = nn.functional.normalize(y1, dim=-1), nn.functional.normalize(y2, dim=-1)
        # Rows are unit length, so the product is a matrix of cosine similarities.
        logits = y1 @ y2.t()
        # Build the targets on the logits' device; a CPU-only arange would raise
        # a device-mismatch error when the embeddings live on an accelerator.
        labels = torch.arange(len(y1), device=logits.device)

        return self.entropy(logits / self.temperature, labels)


class ContrastiveLoss2(nn.Module):
    """Contrastive loss that maps cosine similarities to logits via ``2 * atanh``.

    Row ``i`` of ``y1`` is treated as the positive match of row ``i`` of ``y2``.

    Args:
        eps: Margin kept between the similarities and +/-1 before ``atanh`` is
            applied. ``atanh(+/-1)`` is infinite, which turns the cross-entropy
            into NaN as soon as one pair is perfectly (anti-)aligned. The
            default is sized for float32 inputs.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.entropy = nn.CrossEntropyLoss()
        self.eps = eps

    def forward(self, y1, y2, names):
        """Return the cross-entropy of the transformed similarities against the diagonal.

        Args:
            y1: First batch of embeddings; expected to be 2-D because ``.t()``
                is applied to its counterpart.
            y2: Second batch of embeddings, row-aligned with ``y1``.
            names: Unused; accepted so all losses share one call signature.

        Returns:
            Scalar loss tensor.
        """
        y1, y2 = nn.functional.normalize(y1, dim=-1), nn.functional.normalize(y2, dim=-1)
        logits = y1 @ y2.t()
        # Keep the similarities strictly inside (-1, 1) so atanh stays finite.
        logits = logits.clamp(-1 + self.eps, 1 - self.eps)
        logOdds = 2 * torch.atanh(logits)
        # Targets must live on the same device as the logits.
        labels = torch.arange(len(y1), device=logits.device)

        return self.entropy(logOdds, labels)


class Perplexity(nn.Module):
    """Metric wrapper that reports another loss as ``exp(loss)``."""

    def __init__(self, loss):
        """Store ``loss``, a callable taking ``(y1, y2, names)``."""
        super().__init__()
        self.loss = loss

    def forward(self, y1, y2, names):
        """Evaluate the wrapped loss on the inputs and return its exponential."""
        return torch.exp(self.loss(y1, y2, names))

In [4]:
def saveExperiment(imageModel, textModel, config, experimentName, start):
    """Write one experiment to both the ``latest`` and a time-stamped checkpoint folder.

    Two copies are saved under ``checkpoints/finetune``: one in ``latest`` and
    one in a folder named after ``start`` at minute resolution.

    Args:
        imageModel: Image model, forwarded to ``saveToPath``.
        textModel: Text model, forwarded to ``saveToPath``.
        config: Configuration object, forwarded to ``saveToPath``.
        experimentName: Sub-folder name used inside each checkpoint folder.
        start: ``datetime``-like object; only ``strftime`` is called on it.
    """
    root = os.path.join("checkpoints", "finetune")
    latestPath = os.path.join(root, "latest")
    timePath = os.path.join(root, start.strftime("%Y-%m-%d %H-%M"))

    for path in (latestPath, timePath):
        # makedirs also creates missing parents (a fresh checkout has no
        # checkpoints/finetune yet) and tolerates an already existing folder.
        os.makedirs(path, exist_ok=True)
        saveToPath(path, imageModel, textModel, config, experimentName)


def saveToPath(path, imageModel, textModel, config, experimentName):
    """Save both models and the config into ``path/experimentName``.

    Files written inside that folder:
        ``image.pt``     - ``imageModel.state_dict()`` via ``torch.save``
        ``text``         - whatever ``textModel.save_pretrained`` writes
        ``config.json``  - whatever ``config.save`` writes

    Args:
        path: Parent folder; created (with parents) when missing.
        imageModel: Object exposing ``state_dict()``.
        textModel: Object exposing ``save_pretrained(directory)``.
        config: Object exposing ``save(filename)``.
        experimentName: Name of the sub-folder to write into.
    """
    target = os.path.join(path, experimentName)
    # Creates any missing parent as well and is a no-op when the folder exists.
    os.makedirs(target, exist_ok=True)

    torch.save(imageModel.state_dict(), os.path.join(target, "image.pt"))
    textModel.save_pretrained(os.path.join(target, "text"))
    config.save(os.path.join(target, "config.json"))

In [5]:
def trainModel(config, textModel, imageModel, dataset, experimentName, start, imageConfig):
    """Jointly fine-tune an image model and a text model with a contrastive objective.

    After every training step one batch from the test split is evaluated, and
    the losses and perplexities of both are sent to a ``Client`` at
    127.0.0.1:12945. Training stops early once at least two earlier epochs
    had a lower average test loss than the current one, and can also be cut
    short with a keyboard interrupt. In every one of these cases the models
    are saved through ``saveExperiment`` before returning.

    Args:
        config: Object providing ``imageLearningRate``, ``textLearningRate``,
            ``batchSize`` and ``epochs``.
        textModel: Text encoder; called as ``textModel(**text)`` and read via
            ``.pooler_output``.
        imageModel: Image encoder; called as ``imageModel(images)``.
        dataset: Object whose ``split(dataset, batchSize=...)`` returns the
            train and test loaders; each batch unpacks into
            ``(images, targets, info, text)``.
        experimentName: Folder name handed to ``saveExperiment``.
        start: Timestamp handed to ``saveExperiment``.
        imageConfig: Configuration object that is saved with the models.

    Returns:
        Tuple ``(imageModel, textModel)``.
    """
    imageOptimizer = torch.optim.Adam(imageModel.parameters(), lr=config.imageLearningRate)
    # The text optimizer has to own the *text* model's parameters. Building it
    # from imageModel.parameters() left the text encoder untrained and stepped
    # the image model twice per batch.
    textOptimizer = torch.optim.Adam(textModel.parameters(), lr=config.textLearningRate)

    # ContrastiveLoss2() has the same call signature and can be swapped in here.
    objective = ContrastiveLoss(temperature=0.03)
    criterion = Perplexity(objective)

    train, test = dataset.split(dataset, batchSize=config.batchSize)

    # Endless stream of test batches so one can be drawn per training step.
    testIter = itertoolsBetter(test)

    client = Client("127.0.0.1", 12945)

    testHistory = []

    try:
        for epoch in range(config.epochs):
            progress = 0
            averageTrainLoss = 0
            averageTestLoss = 0
            for images, _targets, info, text in train:
                imageModel.train()
                textModel.train()
                imageOptimizer.zero_grad()
                textOptimizer.zero_grad()

                imageOutputs = imageModel(images)
                textOutputs = textModel(**text).pooler_output
                loss = objective(imageOutputs, textOutputs, info)

                trainLoss = loss.detach().item()
                # Incremental mean over the batches seen so far in this epoch.
                averageTrainLoss = (averageTrainLoss * progress + trainLoss) / (progress + 1)

                loss.backward()
                textOptimizer.step()
                imageOptimizer.step()

                with torch.no_grad():
                    # Reporting only, so no autograd graph is needed.
                    trainPerplexity = criterion(imageOutputs, textOutputs, info)

                    imageModel.eval()
                    textModel.eval()
                    images1, _targets1, info1, text1 = next(testIter)
                    imageOutputs1 = imageModel(images1)
                    textOutputs1 = textModel(**text1).pooler_output
                    loss1 = objective(imageOutputs1, textOutputs1, info1)
                    testPerplexity = criterion(imageOutputs1, textOutputs1, info1)

                    testLoss = loss1.detach().item()
                    averageTestLoss = (averageTestLoss * progress + testLoss) / (progress + 1)

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)
                client.send("Train Perplexity", trainPerplexity.detach().item())
                client.send("Test Perplexity", testPerplexity.detach().item())

                progress += 1

                progressString = f"\r{epoch + 1} | {progress}/{len(train)} | {(progress / len(train)) * 100:.3f}%"

                print(f"{progressString} |  Train Loss: {averageTrainLoss:.2f} | Test Loss: {averageTestLoss:.2f}",end="")

            print(f"\rEpoch {epoch + 1} | Train Loss: {averageTrainLoss:.2f} | Test Loss: {averageTestLoss:.2f}{' ' * 50}")

            # Early stopping: at least two earlier epochs beat this one.
            if (np.array(testHistory) < averageTestLoss).sum() >= 2:
                break
            testHistory.append(averageTestLoss)

    except KeyboardInterrupt:
        # A manual interrupt ends training but still saves the current weights.
        pass
    finally:
        # Release the socket on every exit path, including unexpected errors.
        client.socket.close()

    saveExperiment(imageModel, textModel, imageConfig, experimentName, start)
    return imageModel, textModel

In [6]:
# Settings for the fine-tuning runs below. The object is handed to trainModel
# as `config`, which reads imageLearningRate, textLearningRate, batchSize and
# epochs from it.
queryConfig = Config().load(os.path.join("configs", "querying.json"))

In [None]:
# Experiment grid: every image model below is fine-tuned against every text
# backbone. textModels and textModelNames are parallel lists.
imageModelNames = ["masked", "lower", "upper", "CLIP"]
textModels = [CLIPTextModel, BertModel]
textModelNames = ["openai/clip-vit-base-patch32", "bert-base-uncased"]

# This initial load is only used for its dataset configuration; the image model
# itself is rebuilt for every experiment inside the loop.
imageModel, imageConfig = UNet.load(os.path.join("checkpoints", "pretrain", "upper"))

imageConfig.dataset.directory = "dataset"
dataset = MyFontsData(imageConfig.dataset)

for textModelClass, textModelName in zip(textModels, textModelNames):
    dataset.setTokenizer(textModelName)
    cfg = AutoConfig.from_pretrained(textModelName)
    # Width the image head must project to: hidden_size when the config has
    # one, otherwise projection_dim.
    if hasattr(cfg, "hidden_size"):
        textDimension = cfg.hidden_size
    else:
        textDimension = cfg.projection_dim

    for imageModelName in imageModelNames:
        # Start every experiment from the pretrained text weights so that the
        # fine-tuning of one experiment does not leak into the next one.
        textModel = textModelClass.from_pretrained(textModelName)

        if imageModelName == "CLIP":
            # Reuses the model config of the most recently loaded UNet checkpoint.
            imageConfig.model["textProjection"] = textDimension
            imageModel = CLIPEmbedder(imageConfig.model)
        else:
            # Load the checkpoint that matches the experiment name. A fixed
            # "upper" path here made the masked/lower/upper runs identical.
            imageModel, imageConfig = UNet.load(os.path.join("checkpoints", "pretrain", imageModelName))
            imageConfig.model["textProjection"] = textDimension
            currentDim = imageConfig.model.filters * imageConfig.model.expansion ** imageConfig.model.layers
            # Replace the classifier head with a projection into the text embedding width.
            imageModel.classifier = nn.Linear(currentDim, textDimension)

        if "method" in imageConfig.dataset:
            dataset.method = imageConfig.dataset.method

        if imageModelName == "CLIP":
            dataset.method = "masked"

        experimentName = imageModelName + " " + textModelName.replace("/", "-")

        if hasattr(imageModel, "outputType"):
            imageModel.outputType = "pooled"

        print(f"\n{'=' * 28}\n{experimentName}\n{'=' * 28}")

        imageModel, textModel = trainModel(queryConfig, textModel, imageModel, dataset, experimentName, datetime.now(), imageConfig)

        # trainModel has already saved under its start time; this second save
        # is stamped with the finish time and refreshes "latest" again.
        saveExperiment(imageModel, textModel, imageConfig, experimentName, datetime.now())