In [None]:
from utils import *

In [None]:
def itertoolsBetter(dataIter):
    """Yield items from ``dataIter`` over and over by re-iterating it.

    Each pass starts a fresh ``for`` loop over ``dataIter`` rather than
    replaying stored items (as ``itertools.cycle`` does), so an iterable
    that produces a different order on each iteration is re-evaluated on
    every pass.

    The generator ends as soon as a full pass yields nothing. That happens
    when ``dataIter`` is empty, or when it is a one-shot iterator that was
    exhausted by the previous pass; without this check the loop would spin
    forever without ever yielding.
    """
    while True:
        yieldedAny = False
        for batch in dataIter:
            yieldedAny = True
            yield batch
        if not yieldedAny:
            # Nothing left to produce; stop instead of busy-looping.
            return

In [None]:
class ContrastiveLoss(nn.Module):
    """Cross-entropy over pairwise similarity logits.

    The logits are ``y1 @ y2.t()`` divided by ``temperature``; the target
    for row ``i`` is column ``i``, i.e. the i-th entry of ``y1`` is expected
    to match the i-th entry of ``y2``.

    NOTE(review): assumes ``y1`` and ``y2`` are 2-D ``(batch, dim)`` tensors
    with the same batch size -- confirm against callers.
    """

    def __init__(self, temperature=0.1):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise setting ``self.entropy`` raises AttributeError.
        super().__init__()
        self.entropy = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, y1, y2):
        """Return the scalar loss for the paired batches ``y1`` and ``y2``."""
        logits = y1 @ y2.t()
        # Build the targets on the logits' device so the loss does not
        # depend on the process-wide default device.
        labels = torch.arange(logits.shape[0], device=logits.device)

        return self.entropy(logits / self.temperature, labels)

In [None]:
def saveToPath(path, imageModel, textModel, config, experimentName):
    """Write both models' state dicts and the config under ``path/experimentName``.

    Creates ``path/experimentName`` (including any missing parents) and
    writes three files into it:

    * ``image.pt``    -- ``imageModel.state_dict()`` via ``torch.save``
    * ``text.pt``     -- ``textModel.state_dict()`` via ``torch.save``
    * ``config.json`` -- written by ``config.save(...)``
    """
    experimentDir = os.path.join(path, experimentName)
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # works when ``path`` itself has not been created yet.
    os.makedirs(experimentDir, exist_ok=True)

    torch.save(imageModel.state_dict(), os.path.join(experimentDir, "image.pt"))
    torch.save(textModel.state_dict(), os.path.join(experimentDir, "text.pt"))
    config.save(os.path.join(experimentDir, "config.json"))

In [None]:
def trainModel(config, textModelClass, imageModel, datasetClass, experimentName, start):
    """Jointly train an image model and a text model with ``ContrastiveLoss``.

    Args:
        config: Provides ``dataset``, ``textModel``, ``learningRate``,
            ``batchSize`` and ``epochs``; also passed on to ``saveToPath``.
        textModelClass: Callable building the text model from ``config.textModel``.
        imageModel: Already-constructed image model (an instance, not a class).
        datasetClass: Callable building the dataset from ``config.dataset``;
            the dataset must expose a ``collate`` function for the loaders.
        experimentName: Sub-directory name used when saving checkpoints.
        start: Object whose ``strftime`` names the timestamped checkpoint
            directory (presumably a ``datetime`` -- confirm with the caller).

    Returns:
        ``(imageModel, textModel)`` once training finishes or is interrupted
        with Ctrl-C. In both cases checkpoints are written to
        ``checkpoints/finetune/latest`` and ``checkpoints/finetune/<stamp>``.
    """
    dataset = datasetClass(config.dataset)
    textModel = textModelClass(config.textModel)

    # Both networks are optimised together, so gather their parameters once.
    parameters = list(imageModel.parameters()) + list(textModel.parameters())

    print(f"Model has {sum(p.numel() for p in parameters)} parameters")

    optimizer = torch.optim.Adam(parameters, lr=config.learningRate)

    objective = ContrastiveLoss()

    train, test = torch.utils.data.random_split(dataset, [0.8, 0.2], generator=torch.Generator(device=device))
    train = DataLoader(train, batch_size=config.batchSize, shuffle=True, collate_fn=dataset.collate, generator=torch.Generator(device=device))
    test = DataLoader(test, batch_size=config.batchSize, shuffle=True, collate_fn=dataset.collate, generator=torch.Generator(device=device))

    # Endless stream of test batches: one is evaluated after every train step.
    testIter = itertoolsBetter(test)

    client = Client("127.0.0.1", 12954)

    totalBatches = len(train)

    try:
        for epoch in range(config.epochs):
            for progress, (images, _targets, _info, text) in enumerate(train, start=1):
                imageModel.train()
                textModel.train()
                optimizer.zero_grad()

                imageOutputs = imageModel(images)
                textOutputs = textModel(text)
                loss = objective(imageOutputs, textOutputs)

                trainLoss = loss.detach().item()

                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    imageModel.eval()
                    textModel.eval()
                    images1, _targets1, _info1, text1 = next(testIter)
                    imageOutputs1 = imageModel(images1)
                    textOutputs1 = textModel(text1)
                    loss1 = objective(imageOutputs1, textOutputs1)

                    testLoss = loss1.detach().item()

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)

                print(f"\r{epoch + 1} | {progress}/{totalBatches} | {(progress / totalBatches) * 100:.3f}% |  Train Loss: {trainLoss:.2f} | Test Loss: {testLoss:.2f}", end="")

    except KeyboardInterrupt:
        # Ctrl-C is the supported way to stop early; fall through so the
        # current weights are still checkpointed and returned.
        print("\nTraining interrupted; saving checkpoints.")

    checkpointRoot = os.path.join("checkpoints", "finetune")

    latestPath = os.path.join(checkpointRoot, "latest")
    os.makedirs(latestPath, exist_ok=True)

    stamp = start.strftime("%Y-%m-%d %H-%M")
    timePath = os.path.join(checkpointRoot, stamp)
    os.makedirs(timePath, exist_ok=True)

    saveToPath(latestPath, imageModel, textModel, config, experimentName)
    saveToPath(timePath, imageModel, textModel, config, experimentName)

    return imageModel, textModel