In [1]:
import torch
import pandas

In [2]:
from typing import Any
from torch.utils.data import Dataset

class PandasDataset(Dataset):
    def __init__(self, df) -> None:
        self.dataframe = df

    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, index) -> Any:
        series = self.dataframe.iloc[index]
        # Convert the pandas Series to a numpy array
        array = series.values
        # Convert the numpy array to a tensor
        tensor = torch.from_numpy(array)
        return tensor
    
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        descript = torch.tensor(row['descriptions'])
        img_embed = torch.tensor(row['img_emb'])
        country_encoding = torch.tensor(row['country_enc'])
        cell_target = torch.tensor(row['cell_target'])
        coordinate_target = torch.tensor(row['coordinate_target'])

        # return {
        #     'descript': descript,
        #     'img_embed': img_embed,
        #     'country_encoding': country_encoding,
        #     'cell_target': cell_target,
        #     'coordinate_target': coordinate_target
        # }

        return descript, img_embed, country_encoding, cell_target, coordinate_target

In [13]:
# Description Emb., Img Emb., Country Hot Encoding, Cell Target, Coordinate Target (Lat, Lon)
n = 124
descripts = [torch.randn(716).numpy() for i in range(n)]
img_embeds = [torch.randn(716).numpy() for i in range(n)]

country_encodings = []
for i in range(n):
    enc = torch.zeros(221).numpy()
    enc[torch.randint(0, 221, (1, ))[0]] = 1
    country_encodings.append(enc)

cell_targets = [torch.randint(0, 10000, (1, ))[0].numpy() for i in range(n)]

coordinate_targets = []
for i in range(n):
    enc = torch.randn(2).numpy()
    enc[0] *= 180
    enc[1] *= 90
    coordinate_targets.append(enc)

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)

In [4]:
df = pandas.DataFrame(list(zip(descripts, img_embeds, country_encodings, cell_targets, coordinate_targets)),
               columns =['descriptions', 'img_emb', "country_enc", "cell_target", "coordinate_target"])


In [5]:
pd_dataset = CustomDataset(df)

In [6]:
for batch in torch.utils.data.DataLoader(pd_dataset):
    # print(batch)
    print(type(batch[0]))
    break

<class 'torch.Tensor'>


In [7]:
from pathlib import Path
import sys
sys.path.append(str(Path.cwd().parent))

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from model.attention_module import AttentionWeightedAggregation, LinearAttention
# from model.backbone import LatLongHead, StreetCLIP, TextEncoder #TODO remove unused imports
from model.country_prediction import CountryClassifier
from model.head.geolocation_head import MLPCentroid, HybridHeadCentroid

from model.attention_module import get_pseudo_label_loss

from datasets import load_dataset

clue_embeddings = pd.read_pickle('../data/guidebook_roberta_base_ch_in.pkl')
clue_embeddings = torch.tensor(list(clue_embeddings.values()))
clues = load_dataset("gips-mai/all_clues_enc")['train'][:len(clue_embeddings)]
clues['encoding'] = None

  clue_embeddings = torch.tensor(list(clue_embeddings.values()))


In [21]:


### HYPER PARAMETERS ###
lr = 0.001
alpha = 0.75
use_tanh = True
scale_tanh = 1.2
### HYPER PARAMETERS ###

device = 'cpu'

clue_embedding_size:int = 768
text_embedding_size:int = 512
clip_embedding_size:int = 716

country_encoding = pd.read_csv('../data/encodings.csv')

attention_aggregation = AttentionWeightedAggregation(temperature=0.01) #TODO definde temperature
linear_attention = LinearAttention(attn_input_img_size=clip_embedding_size, text_features_size=clue_embedding_size, hidden_layer_size_0=1024, hidden_layer_size_1=1024) #TODO hidden layer size
country_classifier = CountryClassifier(clue_embedding_size=clue_embedding_size, image_embedding_size=clip_embedding_size, alpha=alpha)

previous_stage_output = text_embedding_size+clip_embedding_size+clue_embedding_size
geohead = MLPCentroid(initial_dim=previous_stage_output, hidden_dim=[previous_stage_output, 1024, 512])
hybrid_head_centroid = HybridHeadCentroid(final_dim=11398, quadtree_path='../data/quad_tree/quadtree_10_1000.csv', use_tanh=use_tanh, scale_tanh=scale_tanh)

optimizer = optim.Adam(list(country_classifier.parameters()) + list(geohead.parameters()))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5)

cell_loss = nn.CrossEntropyLoss()
coordinate_loss = nn.MSELoss()

# clues = load_dataset("gips-mai/all_clues_enc")['train'][:20]
# clue_embeddings = torch.tensor(clues['encoding'])


# clues = load_dataset("gips-mai/all_clues_enc")['train'][:2]

pseudo_label_loss = get_pseudo_label_loss(clues["country_one_hot_enc"])
# descriptions = load_dataset("gips-mai/enc_descr")
# data_loader = torch.utils.data.DataLoader(descriptions, batch_size=32, shuffle=True)

data_loader = torch.utils.data.DataLoader(pd_dataset, batch_size = 16, shuffle=False)

country_losses = []
geo_losses = []
for epoch in range(10):
    for batch in data_loader:
        descriptions, imgs, country_target, cell_target, coordinate_target = batch
        imgs, descriptions, country_target, cell_target, coordinate_target = imgs.to(device), descriptions.to(device), country_target.to(device), cell_target.to(device), coordinate_target.to(device)

        optimizer.zero_grad()

        attention = linear_attention.forward(img_embedding=imgs)
        weighted_aggregation = attention_aggregation.forward(clue_embeddings=clue_embeddings, attention=attention)

        country_loss = country_classifier.training_step(x=weighted_aggregation, target=country_target) # target: get the iso2 of actual country and then look at the one hot encoding
        country_losses.append(country_loss)

        # pseudo label loss
        current_pseudo_label_loss = pseudo_label_loss(country_target, attention)

        aux_attention_loss = alpha * current_pseudo_label_loss + (1-alpha) * country_loss

        aggregated_input = torch.cat([imgs, descriptions, weighted_aggregation], dim=1)

        intermediate = geohead.forward(aggregated_input)
        prediction = hybrid_head_centroid.forward(intermediate, cell_target)

        total_loss = cell_loss.apply(prediction['label'], cell_target) + \
                     coordinate_loss.apply(prediction['gps'], coordinate_target) + \
                     aux_attention_loss

        total_loss.backward()
        optimizer.step()
    
    scheduler.step()

    # Print the loss at each epoch
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}")
    print(f"Epoch {epoch+1}, Loss: {country_loss.item():.4f}")


torch.Size([16, 768])
torch.Size([768, 3817])
divisor 2931456
agg torch.Size([16])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 1484x1024)

In [19]:
x

NameError: name 'x' is not defined

In [16]:
len(country_target[0])

221

In [None]:
x_tensor

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0])

In [None]:
y = torch.tensor([[0,1], [0,1]])
y[0] = torch.tensor([1,1])

In [None]:
torch.tensor(list(clue_embeddings.values()))

  torch.tensor(list(clue_embeddings.values()))


tensor([[-0.0618,  0.1175, -0.0703,  ..., -0.0676,  0.0133, -0.0567],
        [-0.0505,  0.1200, -0.0689,  ..., -0.0360,  0.0539, -0.0521],
        [-0.0460,  0.1177, -0.0573,  ..., -0.0460,  0.0242, -0.0192],
        ...,
        [-0.0966,  0.1018, -0.0211,  ..., -0.0497, -0.0341, -0.0596],
        [-0.0636,  0.1438, -0.0114,  ..., -0.0722, -0.0099, -0.0941],
        [-0.0786,  0.1549,  0.0098,  ..., -0.0329,  0.0091, -0.0597]])