In [None]:
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import os
import pandas as pd
import torch
import seaborn as sns
import ast

In [None]:
def load_data(REPO_PATH, dataset_name, country_list):
    """Load all embedding CSVs of one dataset and attach country metadata.

    Parameters
    ----------
    REPO_PATH : str
        Repository root; files are read from ``<REPO_PATH>/Embeddings/Image/``.
    dataset_name : str
        Filename prefix; only directory entries starting with it are read.
    country_list : pandas.DataFrame
        Lookup table with a ``Country`` column, joined onto the ``label``
        column of the embedding files.

    Returns
    -------
    pandas.DataFrame
        The matching CSVs concatenated (index reset) and merged with
        ``country_list``. ``pd.merge`` performs an inner join by default, so
        rows whose ``label`` has no matching ``Country`` are dropped.

    Raises
    ------
    ValueError
        If no file in the directory starts with ``dataset_name``.
    """
    # Directory containing CSV files
    directory = f'{REPO_PATH}/Embeddings/Image/'

    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of the result identical on every run and every machine.
    file_list = sorted(
        file for file in os.listdir(directory) if file.startswith(dataset_name)
    )
    if not file_list:
        # Fail with a message that names the cause instead of pandas'
        # generic "No objects to concatenate".
        raise ValueError(
            f"No files starting with '{dataset_name}' found in '{directory}'"
        )

    # Read every file as a DataFrame
    dfs = [pd.read_csv(os.path.join(directory, file)) for file in file_list]

    # Concatenate all DataFrames into a single DataFrame
    combined_df = pd.concat(dfs, ignore_index=True)

    # Map countries to regions
    combined_df = pd.merge(combined_df, country_list, left_on='label', right_on='Country')

    return combined_df

In [None]:
def transform_to_tensor(x):
    """Parse a stringified 2-D embedding into a ``(1, n)`` NumPy array.

    Despite its name, the function returns a NumPy array, not a torch tensor.

    Parameters
    ----------
    x : str
        Text containing a nested list literal delimited by ``[[`` and ``]]``.
        Anything before ``[[`` or after the closing ``]]`` is ignored.

    Returns
    -------
    numpy.ndarray
        The embedding values flattened and reshaped to ``(1, n)``. The values
        pass through ``torch.tensor`` first, so they carry the precision of
        torch's default dtype for the parsed literal.

    Raises
    ------
    ValueError
        If ``x`` does not contain ``[[`` followed by ``]]``.
    """
    start = x.find('[[')
    # Search for the closing brackets only after the opening ones.
    end = x.find(']]', start) if start != -1 else -1
    if start == -1 or end == -1:
        # str.find returns -1 on a miss; slicing with it would silently yield
        # an empty or truncated literal and fail later with an unrelated error.
        raise ValueError(f"No '[[...]]' embedding literal found in: {x[:80]!r}")

    literal = x[start:end + 2]
    image_embedding = torch.tensor(ast.literal_eval(literal))
    image_embedding_values = np.array(image_embedding.flatten().tolist()).reshape(1, -1)
    return image_embedding_values

In [None]:
# Repository root. The REPO_PATH environment variable overrides the default so
# the notebook can run on machines other than the one it was written on.
REPO_PATH = os.environ.get(
    'REPO_PATH',
    '/home/kieran/Documents/Uni/WiSe23-24/Good_Practices_of_Machine_Learning/good_practices_ml/',
)
country_list = pd.read_csv(f'{REPO_PATH}/data_finding/country_list_region_and_continent.csv')

# Set dataset_name
dataset_name = 'bigfoto'

# Fixed seed so the t-SNE layout (and every plot derived from it) is reproducible
RANDOM_SEED = 42

# Load data
combined_df = load_data(REPO_PATH, dataset_name, country_list)
y = combined_df['label'].to_numpy()

# Region and continent of every sample, parallel to `y`.
# An indexed lookup (first row per country, like the former `.iloc[0]`) replaces
# a full scan of `country_list` for each sample.
country_lookup = country_list.drop_duplicates(subset='Country').set_index('Country')
regions = combined_df['label'].map(country_lookup['Intermediate Region Name']).tolist()
continents = combined_df['label'].map(country_lookup['Continent']).tolist()

# Get sets of country, region and continent classes
country_classes = np.unique(y)
region_classes = np.unique(regions)
continent_classes = np.unique(continents)

# Create numpy array for TSNE.
# Each parsed embedding has shape (1, n); stacking along axis 0 yields
# (n_samples, n) even for a single sample or n == 1, where np.squeeze would
# collapse the wrong axis.
X = np.concatenate(combined_df["Embedding"].apply(transform_to_tensor).to_list(), axis=0)
# One name per embedding dimension (len(X[0]) on the unparsed list was always 1)
feat_cols = ['pixel' + str(i) for i in range(X.shape[1])]
print(X.shape)

# Run TSNE
tsne = TSNE(n_components=2, verbose=1, init='pca', random_state=RANDOM_SEED)
tsne_results = tsne.fit_transform(X)

In [None]:
# 2-D t-SNE coordinates of every sample, labelled with the sample's continent
df_subset = pd.DataFrame({
    'tsne-2d-one': tsne_results[:, 0],
    'tsne-2d-two': tsne_results[:, 1],
    'Classes': continents,
})

plt.figure(figsize=(16,10), clear=True)
scatterplot = sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="Classes",
    # One colour per continent actually present instead of a hard-coded 6
    palette=sns.color_palette("hls", df_subset['Classes'].nunique()),
    data=df_subset,
    legend="full",
    alpha=0.3
)

# makedirs creates missing parents and tolerates an existing directory
world_dir = f'./{dataset_name}/World'
os.makedirs(world_dir, exist_ok=True)
scatterplot.figure.savefig(f'{world_dir}/output.png')

In [None]:
# One label list per continent, each parallel to `y`: samples belonging to the
# continent keep their region name, every other sample is labelled 'Other'.
# `regions` and `continents` already hold the per-sample lookup results, so
# `country_list` does not have to be searched again for every sample.
continent_specific_labels = [
    [
        region if sample_continent == continent else 'Other'
        for region, sample_continent in zip(regions, continents)
    ]
    for continent in continent_classes
]

In [None]:
# 2-D t-SNE coordinates; the 'Classes' column is replaced for every continent
df_subset = pd.DataFrame({
    'tsne-2d-one': tsne_results[:, 0],
    'tsne-2d-two': tsne_results[:, 1],
})
continent_root = f'./{dataset_name}/Continent'

for continent, labels in zip(continent_classes, continent_specific_labels):
    df_subset['Classes'] = labels

    # Regions of this continent get distinct hues; 'Other' is drawn last in grey
    modified_unique = [name for name in np.unique(labels) if name != 'Other']
    color_palette = sns.color_palette("hls", len(modified_unique)).as_hex()
    color_palette.append('#d3d3d3')
    modified_unique.append('Other')

    plt.figure(figsize=(16,10), clear=True)
    scatterplot = sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="Classes",
        hue_order=modified_unique,
        palette=sns.color_palette(color_palette),
        data=df_subset,
        legend="full",
        alpha=0.3
    )

    # makedirs also creates './<dataset_name>' if it does not exist yet, so
    # this cell no longer depends on an earlier cell having created it
    output_dir = f'{continent_root}/{continent}'
    os.makedirs(output_dir, exist_ok=True)
    scatterplot.figure.savefig(f'{output_dir}/output.png')

    # Render the figure, then release it so figures do not pile up in the loop
    plt.show()
    plt.close(scatterplot.figure)

In [None]:
# One label list per region, each parallel to `y`: samples belonging to the
# region keep their country label, every other sample is labelled 'Other'.
# `regions` already holds the per-sample region lookup, so `country_list` does
# not have to be searched again for every sample.
region_specific_labels = [
    [
        country if sample_region == region else 'Other'
        for country, sample_region in zip(y, regions)
    ]
    for region in region_classes
]

In [None]:
# 2-D t-SNE coordinates; the 'Classes' column is replaced for every region
df_subset = pd.DataFrame({
    'tsne-2d-one': tsne_results[:, 0],
    'tsne-2d-two': tsne_results[:, 1],
})
region_root = f'./{dataset_name}/Region'

for region, labels in zip(region_classes, region_specific_labels):
    df_subset['Classes'] = labels

    # Countries of this region get distinct hues; 'Other' is drawn last in grey
    modified_unique = [name for name in np.unique(labels) if name != 'Other']
    color_palette = sns.color_palette("hls", len(modified_unique)).as_hex()
    color_palette.append('#d3d3d3')
    modified_unique.append('Other')

    plt.figure(figsize=(16,10), clear=True)
    scatterplot = sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="Classes",
        hue_order=modified_unique,
        palette=sns.color_palette(color_palette),
        data=df_subset,
        legend="full",
        alpha=0.3
    )

    # makedirs also creates './<dataset_name>' if it does not exist yet, so
    # this cell no longer depends on an earlier cell having created it
    output_dir = f'{region_root}/{region}'
    os.makedirs(output_dir, exist_ok=True)
    scatterplot.figure.savefig(f'{output_dir}/output.png')

    # Render the figure, then release it: one figure per region would
    # otherwise stay open and can exceed matplotlib's open-figure warning limit
    plt.show()
    plt.close(scatterplot.figure)