In [14]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
import pandas

  from .autonotebook import tqdm as notebook_tqdm


In [32]:
# Fetch the 'train' split of the gips-mai/osv5m_ann dataset from the Hugging Face Hub.
# The recorded run of this cell shows a download of several hundred MB.
train_data = load_dataset(path="gips-mai/osv5m_ann", split="train")

Downloading readme: 100%|██████████| 2.03k/2.03k [00:00<00:00, 841kB/s]
Downloading data: 100%|██████████| 305M/305M [22:29<00:00, 226kB/s] 
Downloading data:  76%|███████▌  | 231M/305M [16:42<05:25, 228kB/s] Error while downloading from https://huggingface.co/datasets/gips-mai/osv5m_ann/resolve/e3d2aa7fdcb0dbaf996f2dc28b80f1a363d544ef/data/01-00001-of-00002.parquet: HTTPSConnectionPool(host='cdn-lfs-us-1.huggingface.co', port=443): Read timed out.
Trying to resume download...
Trying to resume download...
Downloading data:  79%|███████▉  | 241M/305M [21:29<12:21, 85.9kB/s]

KeyboardInterrupt: 

Downloading data:  79%|███████▉  | 241M/305M [21:40<12:21, 85.9kB/s]

In [3]:
# Load the encoded clues; each row's 'country_one_hot_enc' holds a list of encodings.
clues = load_dataset("gips-mai/all_clues_enc", split='train')

# Width of a country encoding vector. It is read from the first non-empty
# encoding instead of a fixed row index, because some rows carry an empty
# encoding (`[[]]`) and would yield a width of 0 or raise IndexError.
len_countries = next(
    (len(enc) for row in clues['country_one_hot_enc'] for enc in row if len(enc) > 0),
    None,
)
if len_countries is None:
    raise ValueError("no non-empty country encoding found in 'country_one_hot_enc'")
len_countries

221

In [4]:
# Collapse the list of encodings stored per row into a single integer vector
# (element-wise sum). Rows containing an empty encoding are printed so they
# can be inspected; the empty entry itself contributes nothing to the sum.
one_hot_encoding = []

for row_encodings in clues['country_one_hot_enc']:
    combined = torch.zeros(len_countries)
    for encoding in row_encodings:
        if len(encoding) == 0:
            print(row_encodings)
            continue
        combined += torch.Tensor(encoding)
    # Stored as a plain list of ints so it can be placed in a DataFrame cell.
    one_hot_encoding.append(list(combined.numpy().astype('int')))


[[]]


In [5]:
# Convert the clues dataset to a DataFrame and replace the nested
# 'country_one_hot_enc' column with the flattened per-row vectors built above.
# NOTE(review): the name `csv` shadows the stdlib module of the same name; it is
# kept because later cells reference it.
csv = clues.to_pandas().assign(country_one_hot_enc=one_hot_encoding)
csv

In [7]:
#csv.to_csv('../data/all_clues_batchable.csv', index=False)

In [7]:
from typing import Any
from torch.utils.data import Dataset

class PandasDataset(Dataset):
    """Map-style dataset that serves the rows of a pandas DataFrame.

    Each item is the row at the requested position, returned as a list of its
    column values in column order.
    """

    def __init__(self, df) -> None:
        """Store a reference to ``df``; the frame is not copied."""
        self.dataframe = df

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        n_rows = len(self.dataframe)
        return n_rows

    def __getitem__(self, index) -> Any:
        """Return the row at positional ``index`` as a list of column values."""
        row = self.dataframe.iloc[index]
        return [*row]

In [25]:
# DataLoader fetches samples with integer indices (`dataset[i]`). On a raw
# DataFrame that is a column lookup and raises KeyError, so the frame is
# wrapped in PandasDataset, which resolves the index positionally via `.iloc`.
csv_loader = DataLoader(PandasDataset(csv), batch_size=2, shuffle=False)

In [15]:
# Random placeholder samples, one list per column:
# Description Emb., Img Emb., Country Hot Encoding, Cell Target, Coordinate Target (Lat, Lon)
n = 124

# 716-dimensional double-precision vectors standing in for real embeddings.
descripts = []
for _ in range(n):
    descripts.append(torch.randn(716, dtype=torch.double).numpy())

img_embeds = []
for _ in range(n):
    img_embeds.append(torch.randn(716, dtype=torch.double).numpy())

# One-hot vectors of width 221 with a single randomly chosen position set to 1.
country_encodings = []
for _ in range(n):
    one_hot = torch.zeros(221, dtype=torch.double).numpy()
    hot_index = torch.randint(0, 221, (1, ))[0]
    one_hot[hot_index] = 1
    country_encodings.append(one_hot)

# Random integer class ids in [0, 10000), each stored as a 0-d numpy array.
cell_targets = []
for _ in range(n):
    cell_targets.append(torch.randint(0, 10000, (1, ))[0].numpy())

# Two standard-normal draws scaled by 180 and 90 respectively.
# NOTE(review): the header comment says (Lat, Lon), but the 180 factor on
# element 0 suggests (Lon, Lat) - confirm the intended order. The values are
# unbounded because they come from a normal distribution.
coordinate_targets = []
for _ in range(n):
    coords = torch.randn(2, dtype=torch.double).numpy()
    coords[0] *= 180
    coords[1] *= 90
    coordinate_targets.append(coords)


In [17]:
# Assemble the synthetic columns into one DataFrame (one sample per row) and
# expose it through PandasDataset so it can be batched by a DataLoader.
column_names = ['descriptions', 'img_emb', "country_enc", "cell_target", "coordinate_target"]
rows = list(zip(descripts, img_embeds, country_encodings, cell_targets, coordinate_targets))
df = pandas.DataFrame(rows, columns=column_names)
pd_dataset = PandasDataset(df)

In [28]:
import sys
import os
# Add the directory above the current working directory to the import search
# path. The entry is relative, so it is resolved against the working directory
# at import time.
parent_dir = os.path.join(".", "..")
sys.path.append(parent_dir)

In [29]:
from pathlib import Path
import sys
# sys.path entries must be plain strings; the import system ignores any other
# type, so the Path is converted before being appended.
sys.path.append(str(Path.cwd().parent))

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

import model.head.geolocation_head
from model.head.geolocation_head import MLPCentroid, HybridHeadCentroid
from model.attention_module import AttentionWeightedAggregation, LinearAttention
from model.country_prediction import CountryClassifier

from datasets import load_dataset
import os
print("done")

done


In [31]:

### HYPER PARAMETERS ###
lr = 0.001
alpha = 0.75
use_tanh=True
scale_tanh=1.2
### HYPER PARAMETERS ###

# Fall back to the CPU when no CUDA device is available so that moving tensors
# to `device` does not fail on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

clue_embedding_size:int = 512
text_embedding_size:int = 716
clip_embedding_size:int = 716
final_dim = 11399 # quadtree len

# The geolocation MLP takes the image and text embeddings concatenated.
previous_stage_output = clip_embedding_size +  text_embedding_size #+clip_embedding_size+clue_embedding_size
geohead_mid_network = model.head.geolocation_head.MLPCentroid(initial_dim=previous_stage_output,
                                                              hidden_dim=[previous_stage_output, 1024, 512],
                                                              final_dim=final_dim,
                                                              activation=torch.nn.GELU,
                                                              norm=torch.nn.GroupNorm)

# Path is relative to the notebook's working directory.
quad_tree_path = os.path.join(".", "..", "data", "quad_tree", "quadtree_10_1000.csv")

hybrid_head_centroid = HybridHeadCentroid(final_dim=final_dim,
                                          quadtree_path=quad_tree_path,
                                          use_tanh=use_tanh,
                                          scale_tanh=scale_tanh)

attention_aggregation = AttentionWeightedAggregation(temperature=0.01) #TODO define temperature
linear_attention = LinearAttention(attn_input_img_size=clip_embedding_size, text_features_size=clue_embedding_size, hidden_layer_size_0=1024, hidden_layer_size_1=1024)
country_classifier = CountryClassifier(clue_embedding_size=clue_embedding_size, alpha=alpha)

# Pass `lr` explicitly so that editing the hyperparameter above takes effect
# (0.001 is also Adam's default, so current behaviour is unchanged).
# NOTE(review): only geohead_mid_network's parameters are optimised. If
# linear_attention / country_classifier / hybrid_head_centroid hold trainable
# parameters they are not updated by this optimizer - confirm this is intended.
optimizer = optim.Adam(geohead_mid_network.parameters(), lr=lr)
# Step the learning-rate schedule every 5 scheduler steps (default decay factor).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5)

# Loss functions used by the training loop.
cell_loss = nn.CrossEntropyLoss()
coordinate_loss = nn.MSELoss()
pseudo_label_loss = nn.MSELoss()

#clues = load_dataset("gips-mai/all_clues_enc")
#descriptions = load_dataset("gips-mai/enc_descr")
#data_loader = torch.utils.data.DataLoader(descriptions, batch_size=32, shuffle=True)


In [59]:
# Batch the synthetic dataset in groups of 12 rows, in original row order
# (no shuffling requested).
data_loader = torch.utils.data.DataLoader(dataset=pd_dataset, batch_size=12)
print("done")

done


In [99]:
# Per-step loss histories. Values are stored detached from the autograd graph;
# appending the live loss tensors would keep every step's graph in memory.
country_losses = []
geo_losses = []

for epoch in range(10):
    for batch in data_loader:
        # Unpacked in the column order of the DataFrame behind `pd_dataset`.
        descriptions, imgs, country_target, cell_target, coordinate_target = batch

        #imgs, descriptions, country_target, cell_target, coordinate_target = imgs.to(device), \
        #                                                                     descriptions.to(device), \
        #                                                                     country_target.to(device), \
        #                                                                     cell_target.to(device), \
        #                                                                     coordinate_target.to(device)

        optimizer.zero_grad()


        # Input of the geolocation MLP: image and description embeddings side by side.
        aggregated_input = torch.cat((imgs, descriptions), dim=1)

        attention = linear_attention.forward(img_embedding=imgs)
        # NOTE(review): `clue_embeddings` is not defined in any visible cell and is
        # not part of `batch`; on a fresh kernel this line raises NameError.
        # Confirm where it is supposed to come from.
        weighted_aggregation = attention_aggregation.forward(clue_embeddings=clue_embeddings, attention=attention)

        country_loss = country_classifier.training_step(x=weighted_aggregation, target=country_target)
        # target: get the iso2 of actual country and then look at the one hot encoding
        # training_step's return type is not visible here, so only detach if it is a tensor.
        country_losses.append(country_loss.detach() if torch.is_tensor(country_loss) else country_loss)

        # pseudo label loss
        current_pseudo_label_loss = pseudo_label_loss(country_target, attention)
        aux_attention_loss = alpha * current_pseudo_label_loss + (1-alpha) * country_loss

        location_prediction = hybrid_head_centroid.forward(geohead_mid_network.forward(aggregated_input), cell_target)

        current_coordinate_loss = coordinate_loss(location_prediction['gps'].float(), coordinate_target)
        current_cell_loss = cell_loss(location_prediction['label'], cell_target)
        geo_losses.append((current_cell_loss.detach(), current_coordinate_loss.detach()))

        total_loss = current_coordinate_loss + current_cell_loss + aux_attention_loss


        total_loss.backward()
        optimizer.step()

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Print the loss at each epoch (this is the loss of the epoch's last batch only)
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}")

torch.float32
tensor(15610.1250, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(12904.0498, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(20201.5625, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(21698.9961, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(14969.1006, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(18674.1504, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(12874.8916, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(18108.3516, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(14438.4570, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(21996.5000, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(46849.3750, grad_fn=<MseLossBackward0>)
torch.float32
Epoch 1, Loss: 46858.9453
torch.float32
tensor(15607.9912, grad_fn=<MseLossBackward0>)
torch.float32
torch.float32
tensor(12903.2725, grad_fn=<MseLossBackward0>)
torch.float32