In [1]:
from utils import *

In [2]:
def itertoolsBetter(dataIter):
    while True:
        for batch in dataIter:
            yield batch

In [3]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super().__init__()
        self.entropy = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, y1, y2):
        logits = y1 @ y2.t()
        labels = torch.arange(len(y1))

        return self.entropy(logits / self.temperature, labels)

In [4]:
def saveExperiment(imageModel, textModel, config, experimentName, start):
    latestPath = os.path.join("checkpoints", "finetune", "latest")
    if not os.path.exists(os.path.join("checkpoints", "finetune", "latest")):
        os.mkdir(latestPath)

    stamp = start.strftime("%Y-%m-%d %H-%M")
    timePath = os.path.join("checkpoints", "finetune", stamp)
    if not os.path.exists(timePath):
        os.mkdir(timePath)

    saveToPath(latestPath, imageModel, textModel, config, experimentName)
    saveToPath(timePath, imageModel, textModel, config, experimentName)


def saveToPath(path, imageModel, textModel, config, experimentName):
    if not os.path.exists(os.path.join(path, experimentName)):
        os.mkdir(os.path.join(path, experimentName))

    torch.save(imageModel.state_dict(), os.path.join(path, experimentName, "image.pt"))
    torch.save(textModel.state_dict(), os.path.join(path, experimentName, "text.pt"))
    config.save(os.path.join(path, experimentName, "config.json"))

In [5]:
def trainModel(config, textModel, imageModel, dataset, experimentName, start, imageConfig):
    imageOptimizer = torch.optim.Adam(imageModel.parameters(), lr=config.imageLearningRate)
    textOptimizer = torch.optim.Adam(imageModel.parameters(), lr=config.textLearningRate)

    objective = ContrastiveLoss(temperature=1)

    train, test = dataset.split(dataset, batchSize=config.batchSize)

    testIter = itertoolsBetter(test)

    client = Client("127.0.0.1", 12954)

    try:
        for epoch in range(config.epochs):
            progress = 0
            for images, targets, info, text in train:
                imageModel.train()
                textModel.train()
                imageOptimizer.zero_grad()
                textOptimizer.zero_grad()

                imageOutputs = imageModel(images)
                textOutputs = textModel(**text).pooler_output
                loss = objective(imageOutputs, textOutputs)

                trainLoss = loss.detach().item()

                loss.backward()
                textOptimizer.step()
                imageOptimizer.step()

                with torch.no_grad():
                    imageModel.eval()
                    textModel.eval()
                    images1, targets1, info1, text1 = next(testIter)
                    imageOutputs1 = imageModel(images1)
                    textOutputs1 = textModel(**text1).pooler_output
                    loss1 = objective(imageOutputs1, textOutputs1)

                    testLoss = loss1.detach().item()

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)

                progress += 1

                progressString = f"\r{epoch + 1} | {progress}/{len(train)} | {(progress / len(train)) * 100:.3f}%"

                print(f"{progressString} |  Train Loss: {trainLoss:.2f} | Test Loss: {testLoss:.2f}",end="")

    except KeyboardInterrupt:
        saveExperiment(imageModel, textModel, imageConfig, experimentName, start)

        return imageModel, textModel

    saveExperiment(imageModel, textModel, imageConfig, experimentName, start)
    return imageModel, textModel

In [6]:
queryConfig = Config().load(os.path.join("configs", "querying.json"))

In [None]:
imageModelNames = ["lower", "upper", "masked", "CLIP"]
textModels = [CLIPTextModel, BertModel]
textModelNames = ["openai/clip-vit-base-patch32", "bert-base-uncased"]

imageModel, imageConfig = UNet.load(os.path.join("checkpoints", "pretrain", "upper"))

imageConfig.dataset.directory = "google"
dataset = QueryData(imageConfig.dataset)

for t, textModelName in enumerate(textModelNames):
    dataset.setTokenizer(textModelName)
    textModel = textModels[t].from_pretrained(textModelName)
    cfg = AutoConfig.from_pretrained(textModelName)
    if hasattr(cfg, "hidden_size"):
        textDimension = cfg.hidden_size
    else:
        textDimension = cfg.projection_dim

    for i, imageModelName in enumerate(imageModelNames):
        if imageModelName == "CLIP":
            imageConfig.model["textProjection"] = textDimension
            imageModel = CLIPEmbedder(imageConfig.model)
        else:
            imageModel, imageConfig = UNet.load(os.path.join("checkpoints", "pretrain", "upper"))
            imageConfig.model["textProjection"] = textDimension
            currentDim = imageConfig.model.filters * imageConfig.model.expansion ** imageConfig.model.layers
            imageModel.classifier = nn.Linear(currentDim, textDimension)

        if "method" in imageConfig.dataset:
            dataset.method = imageConfig.dataset.method

        if imageModelName == "CLIP":
            dataset.method = "masked"

        experimentName = imageModelName + " " + textModelName.replace("/", "-")

        if hasattr(imageModel, "outputType"):
            imageModel.outputType = "pooled"

        print(f"\n{'=' * 28}\n{experimentName}\n{'=' * 28}")

        imageModel, textModel = trainModel(queryConfig, textModel, imageModel, dataset, experimentName, datetime.now(), imageConfig)

        saveExperiment(imageModel, textModel, queryConfig, experimentName, datetime.now())

Fonts serialized: 1821/3794google\fonts\ofl\kumarone\KumarOne-Regular.ttf execution context too long
Fonts serialized: 2335/3794google\fonts\ofl\notocoloremojicompattest\NotoColorEmojiCompatTest-Regular.ttf invalid pixel size
Fonts serialized: 3702/3794google\fonts\ofl\zcoolxiaowei\ZCOOLXiaoWei-Regular.ttf [Errno 22] Invalid argument: 'google\\bitmaps\\????? ?? al.bmp'
Images loaded: 266201/266219
Paths checked: 1301/3620google\fonts\ofl\notocoloremojicompattest\NotoColorEmojiCompatTest-Regular.ttf invalid pixel size
Paths checked: 2101/3620google\fonts\ofl\mplusrounded1c ['Rounded Mplus 1c', 'Rounded Mplus 1c Bold', 'Rounded Mplus 1c', 'Rounded Mplus 1c', 'Rounded Mplus 1c', 'Rounded Mplus 1c', 'Rounded Mplus 1c']
Paths checked: 3601/3620
98.49% of fonts have descriptions

lower openai-clip-vit-base-patch32
1 | 19/1607 | 1.182% |  Train Loss: 10.52 | Test Loss: 10.77