#### Set up

In [2]:
import rasterio as rio
import numpy as np
import pandas as pd
from glob import glob
import os
import math
from tqdm import tqdm
import ee 
import ee_utils
import tensorflow as tf
import pickle
import gzip
import shutil
from tfrecord.torch.dataset import TFRecordDataset
import torch

2023-10-02 09:48:17.470071: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-10-02 09:48:17.541287: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


#### Retrieve Temperatures (Earth Engine ERA5 monthly 2 m air temperature)

In [None]:
# Initialise the Earth Engine client. If that fails (e.g. no cached
# credentials), run the interactive authentication flow and try again.
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

In [None]:
# ========== ADAPT THESE PARAMETERS ==========
# Export destination settings.
# NOTE(review): `export_images` currently hard-codes export='drive' and
# bucket=None, so EXPORT and BUCKET are not read there.
EXPORT = ''
BUCKET = None

# Folder name used as the export prefix for the ERA5 images.
ERA5_EXPORT_FOLDER = ''

# Survey clusters (needs 'lat', 'lon', 'country', 'year' columns).
CSV_PATH = '../data/dataset_2013+.csv'

# ERA5 2 m air-temperature bands to export.
BANDS = [
    'mean_2m_air_temperature',
    'minimum_2m_air_temperature',
    'maximum_2m_air_temperature',
]

# Image export parameters.
PROJECTION = 'EPSG:3857'  # Web Mercator, see https://epsg.io/3857
SCALE = 30                # export resolution in metres per pixel
EXPORT_TILE_RADIUS = 3    # patch size passed as `ksize`; only the central value is used later

# Max number of clusters per exported TFRecord; None puts a whole survey in
# one file. Set to a small number (<= 50) if Google Earth Engine reports
# memory errors.
CHUNK_SIZE = None

csv = pd.read_csv(CSV_PATH)

In [None]:
def export_images(
        df: pd.DataFrame,
        collection: ee.ImageCollection,
        country: str,
        year: int,
        export_folder: str,
        chunk_size = 1,
):
    '''Start Earth Engine exports of monthly median composites around survey clusters.

    For the survey identified by (`country`, `year`) and for every month of
    `year` and of the four preceding years, the `BANDS` images of
    `collection` dated inside that month are reduced with a median. Lat/lon
    bands are added, and one patch-export task is started per chunk of
    clusters.

    Args
    - df: pd.DataFrame, contains columns ['lat', 'lon', 'country', 'year']
    - collection: ee.ImageCollection to composite (e.g. ERA5 monthly)
    - country: str, together with `year` determines the survey to export
    - year: int, together with `country` determines the survey to export
    - export_folder: str, name of folder for export
    - chunk_size: int or None, optionally set a limit to the # of clusters
        exported per TFRecord file; None exports the whole survey as one chunk
        - set to a small number (<= 50) if Google Earth Engine reports memory errors

    Returns: dict, maps (export_folder, country, prev_year, chunk, month) to the
        value returned by `ee_utils.get_array_patches` (presumably an
        ee.batch.Task — confirm). Empty if the survey has no clusters.
    '''
    tasks = {}
    subset_df = df[(df['country'] == country) & (df['year'] == year)].reset_index(drop=True)
    if subset_df.empty:
        # Nothing to export; also avoids a division by zero below.
        return tasks
    if chunk_size is None:
        chunk_size = len(subset_df)
    num_chunks = int(math.ceil(len(subset_df) / chunk_size))
    months = [f'{m:02d}' for m in range(1, 13)]

    for i in range(num_chunks):
        chunk_slice = slice(i * chunk_size, (i + 1) * chunk_size - 1)  # df.loc[] is inclusive
        fc = ee_utils.df_to_fc(subset_df.loc[chunk_slice, :])
        roi = fc.geometry()  # identical for every month of this chunk
        for prev_year in range(year - 4, year + 1):
            for month in months:
                # filterDate's end is exclusive; a day-1..day-28 window stays
                # inside every month. NOTE(review): assumes the collection
                # has its monthly image dated near the start of the month.
                start_date = f'{prev_year}-{month}-01'
                end_date = f'{prev_year}-{month}-28'
                # All BANDS are exported: the processing step requires the
                # mean, minimum and maximum temperature bands.
                composite = (collection.select(BANDS)
                             .filterDate(start_date, end_date)
                             .filterBounds(roi)
                             .median())
                composite = ee_utils.add_latlon(composite)

                fname = f'{country}_{year}_{prev_year}_{i:02d}'
                # The month is part of the key so that monthly tasks of the
                # same chunk do not overwrite each other.
                tasks[(export_folder, country, prev_year, i, month)] = ee_utils.get_array_patches(
                    img=composite, scale=SCALE, ksize=EXPORT_TILE_RADIUS,
                    points=fc, export='drive',
                    prefix=export_folder, fname=fname + '_' + month,
                    bucket=None)
    # Returned only after every chunk has been exported.
    return tasks

In [None]:
collection = ee.ImageCollection("ECMWF/ERA5/MONTHLY")
dataset = pd.read_csv('../data/dataset_2013+.csv')

# One export per (country, year) survey present in the dataset; all the
# returned task handles are merged into `tasks`.
dataset_ = list(dataset.groupby(['country', 'year']).groups.keys())
tasks = {}
for country, year in tqdm(dataset_):
    print(country, year)
    tasks.update(
        export_images(
            df=dataset,
            collection=collection,
            country=country,
            year=year,
            export_folder=ERA5_EXPORT_FOLDER,
            chunk_size=CHUNK_SIZE,
        )
    )


In [None]:
# Temperature bands every exported record must contain.
REQUIRED_BANDS = [
    'minimum_2m_air_temperature',
    'maximum_2m_air_temperature',
    'mean_2m_air_temperature',
]
# Same bands in the same order (separate list object).
BANDS_ORDER = list(REQUIRED_BANDS)

# The raw Earth Engine exports are read from, and the per-cluster TFRecords
# are written under, the same temperature folder.
EXPORT_FOLDER = '../data/additional_data/temperature'
PROCESSED_FOLDER = EXPORT_FOLDER
def validate_and_split_tfrecords(
        tfrecord_paths,
        out_dir: str,
        df: pd.DataFrame,
        country,
        year
        ) -> None:
    '''Validates and splits exported TFRecord files (for a given country-year
    survey) into individual TFRecords, one per (cluster, month).

    Every record must contain all bands in REQUIRED_BANDS. It is then
    re-written on its own, GZIP-compressed, to
    `out_dir/{cluster}_{month}.tfrecord.gz`.
    NOTE(review): the record's other features are NOT compared against `df`.

    Args
    - tfrecord_paths: list of str, paths to GZIP-compressed TFRecord files
        exported from Earth Engine
    - out_dir: str, path to dir to save processed individual TFRecords
    - df: pd.DataFrame, clusters of this survey; only its length is used
        (as the progress-bar total)
    - country, year: identify the survey; currently unused

    Raises
    - ValueError: if a record is missing one of REQUIRED_BANDS
    '''
    options = tf.io.TFRecordOptions(compression_type='GZIP')

    # Context manager so the progress bar is closed even if validation fails.
    with tqdm(total=len(df)) as progbar:
        for tfrecord_path in tfrecord_paths:
            # The month is encoded in the exported file name at these fixed
            # character offsets. NOTE(review): depends on the
            # `..._{month}<suffix>` naming produced by the export step — confirm.
            month = tfrecord_path[-14:-12]
            # Yields the serialized Example messages of the file.
            iterator = tf.compat.v1.io.tf_record_iterator(tfrecord_path, options=options)
            for record_str in iterator:
                ex = tf.train.Example.FromString(record_str)
                feature_map = ex.features.feature
                index = str(int(feature_map["cluster"].float_list.value[0]))
                # Explicit check (not `assert`, which is stripped under -O).
                for band in REQUIRED_BANDS:
                    if band not in feature_map:
                        raise ValueError(f'Band "{band}" not in record {index} of {tfrecord_path}')
                out_path = os.path.join(out_dir, f'{index}_{month}.tfrecord.gz')
                with tf.io.TFRecordWriter(out_path, options=options) as writer:
                    writer.write(ex.SerializeToString())
                progbar.update(1)
    

def process_dataset(csv_path: str, input_dir: str, processed_dir: str) -> None:
    '''Split the Earth Engine temperature exports into per-cluster TFRecords.

    For every (country, year) survey in the CSV and every `prev_year` in
    [year-4, year], the exported files named
    `{country}_{year}_{prev_year}*.tfrecord.gz` are collected. They may sit
    directly in `input_dir` or up to two sub-folders deep. Their records are
    written to `processed_dir/{country}_{year}_{prev_year}/`.

    Args
    - csv_path: str, path to CSV of DHS or LSMS clusters
    - input_dir: str, path to TFRecords exported from Google Earth Engine
    - processed_dir: str, folder where to save processed TFRecords
    '''
    df = pd.read_csv(csv_path, float_precision='high', index_col=False)
    surveys = list(df.groupby(['country', 'year']).groups.keys())  # (country, year) tuples

    for country, year in surveys:
        # The survey subset does not depend on prev_year, so compute it once.
        subset_df = df[(df['country'] == country) & (df['year'] == year)].reset_index(drop=True)
        for prev_year in range(year - 4, year + 1):
            country_year = f'{country}_{year}_{prev_year}'
            print('Processing:', country_year)
            # Look in input_dir and in up to two levels of sub-folders; sort
            # so files are processed in a deterministic order.
            pattern = country_year + '*.tfrecord.gz'
            tfrecord_paths = sorted(
                glob(os.path.join(input_dir, pattern))
                + glob(os.path.join(input_dir, '*', pattern))
                + glob(os.path.join(input_dir, '*', '*', pattern))
            )

            out_dir = os.path.join(processed_dir, country_year)
            os.makedirs(out_dir, exist_ok=True)
            validate_and_split_tfrecords(
                tfrecord_paths=tfrecord_paths, out_dir=out_dir, df=subset_df,
                country=country, year=year)

In [None]:
# Split the raw temperature exports into one TFRecord per (cluster, month).
process_dataset(
    '../data/dataset_viirs_only.csv',
    EXPORT_FOLDER,
    PROCESSED_FOLDER,
)

In [None]:
CSV              = os.path.join( "..", "data", "dataset_viirs_only.csv" )
RECORDS_DIR      = os.path.join( "..", "data", "additional_data", "temperature", "")
TIF_DIR          = os.path.join( "..", "data", "additional_data", "temperature", "" )

csv = pd.read_csv(CSV)
# records[(year, prev_year)][country] -> sorted list of the per-cluster
# `{cluster}_{month}.tfrecord.gz` archives written by process_dataset.
records = dict()
for year in csv.year.unique():
    sub_year = csv[csv.year == year]
    for prev_year in range(year - 4, year + 1):
        records[year, prev_year] = dict()
        for country in sub_year.country.unique():
            # Match the files inside the survey folder: a pattern ending in
            # "/" would match only the folder itself. RECORDS_DIR already
            # starts with "..", so no extra "../" prefix is added.
            pattern = os.path.join(RECORDS_DIR, f'{country}_{year}_{prev_year}', '*.tfrecord.gz')
            records[year, prev_year][country] = sorted(glob(pattern))
            print(pattern)
    # (No `break` here: every survey year is collected.)
def decompress_tfrecord(tfrecord_archive):
    '''Gunzip `tfrecord_archive` next to itself and return the output path.

    The output path is the archive path with its last three characters
    (the ".gz" suffix) removed. An existing file at that path is overwritten.
    '''
    target = tfrecord_archive[:-3]  # path WITHOUT .GZ
    with gzip.open(tfrecord_archive, 'rb') as compressed, open(target, 'wb') as plain:
        shutil.copyfileobj(compressed, plain)
    return target

def tensor_to_string(data, variable):
    '''Return the element [0][0] of `data[variable]` as a string with every "." removed.'''
    first_value = data[variable].numpy()[0][0]
    return ''.join(str(first_value).split('.'))

# Display the collected per-(year, prev_year, country) file lists.
records

In [None]:
# Feature schema handed to `TFRecordDataset` when reading the per-cluster
# records: every listed feature is parsed as "float".
DESCRIPTOR = dict.fromkeys(
    [
        'cluster',
        'lat',
        'lon',
        'wealthpooled',
        'minimum_2m_air_temperature',
        'maximum_2m_air_temperature',
        'mean_2m_air_temperature',
    ],
    'float',
)

# Band order used when reading values out of a record.
BANDNAMES = ['mean_2m_air_temperature', 'minimum_2m_air_temperature', 'maximum_2m_air_temperature']

def tfrecord_to_tif(data, filename, mins, maxs):
    '''Write the BANDNAMES patches of one record to a GeoTIFF and update running extrema.

    Args
    - data: mapping from band name to a tensor holding 9 values (one 3x3
        patch), e.g. one batch of the TFRecordDataset loader.
        NOTE(review): assumes batch_size=1.
    - filename: str, appended to TIF_DIR to form the output path
    - mins, maxs: mutable sequences with one entry per band of BANDNAMES;
        entry i is updated in place with the running min / max of BANDNAMES[i]

    Returns: (mins, maxs), the same objects that were passed in
    '''
    patches = []
    for i, band in enumerate(BANDNAMES):
        patch = data[band].numpy().reshape((3, 3))
        patches.append(patch)
        mins[i] = min(mins[i], patch.min())
        maxs[i] = max(maxs[i], patch.max())

    # (band, row, col) is the layout rasterio's multi-band `write` expects;
    # the previous swapaxes(0, 2) also transposed rows and columns.
    stack = np.stack(patches)
    # rasterio is imported under the alias `rio` at the top of the notebook.
    # All bands are written, and `count` matches the number of bands.
    with rio.open(TIF_DIR + filename, 'w', driver='GTiff',
                  height=stack.shape[1], width=stack.shape[2],
                  count=stack.shape[0], dtype=str(stack.dtype),
                  crs='epsg:3857',
                  transform=None) as tif:
        tif.write(stack)
    return mins, maxs

def read_record(data):
    '''Return the central pixel of each band's 3x3 patch, ordered like BANDNAMES.'''
    return [data[band].numpy().reshape((3, 3))[1, 1] for band in BANDNAMES]


In [None]:
# Filled by the extraction loop: (country, year, prev_year, cluster) -> one
# band triple per month (12 entries).
temperatures = dict()

# Two-digit month string ('01' .. '12') -> 0-based month index.
month_to_index = {f'{month:02d}': month - 1 for month in range(1, 13)}


In [None]:
# Initial values for the temperature mean and standard deviation.
mean_temperature = std_temperature = 0.

In [None]:
# Read the central value of every per-cluster monthly record. For each
# (country, year, prev_year, cluster), store 12 monthly triples ordered like
# BANDNAMES. The dict is checkpointed to disk after every country.
for year, prev_year in records:
    print(year, prev_year)
    for country in records[year, prev_year]:
        print(country)
        for tfrecord_archive in tqdm(records[year, prev_year][country]):
            if tfrecord_archive.endswith('.gz'):
                tfrecord = decompress_tfrecord(tfrecord_archive=tfrecord_archive)
            else:
                tfrecord = tfrecord_archive
            dataset = TFRecordDataset(tfrecord, index_path=None, description=DESCRIPTOR)
            loader = torch.utils.data.DataLoader(dataset, batch_size=1)
            # File names look like `{cluster}_{month}.tfrecord`.
            filename = os.path.basename(tfrecord)
            month = filename[-11:-9]
            key = (country, year, prev_year, int(filename[:-12]))
            for data in loader:
                # 12 independent [0., 0., 0.] placeholders (no shared lists).
                monthly = temperatures.setdefault(key, [[0., 0., 0.] for _ in range(12)])
                monthly[month_to_index[month]] = read_record(data)
        # Same location the "Check temperature dict" cell reads from.
        with open('../data/additional_data/temperatures.pickle', 'wb') as handle:
            pickle.dump(temperatures, handle, protocol=pickle.HIGHEST_PROTOCOL)
            

##### Check temperature dict

In [None]:
import pickle

# Reload the pickled temperature dictionary.
with open('../data/additional_data/temperatures.pickle', 'rb') as pickle_in:
    temperatures = pickle.load(pickle_in)

In [None]:
# Drop the legacy 'mins' entry if present. Unlike `del`, this does not raise
# KeyError when the cell is re-run or the pickle has no such key.
temperatures.pop('mins', None)

In [None]:
# Flatten every entry of `temperatures` into one 1-D float64 array.
# The pieces are collected first and concatenated once. The former
# `arr == np.array([])` test relied on deprecated empty-array truthiness and
# fails with a broadcasting error on recent NumPy; the repeated
# concatenation was also quadratic.
flat_parts = [np.asarray(values, dtype=float).ravel() for values in temperatures.values()]
arr = np.concatenate(flat_parts) if flat_parts else np.array([])
arr

In [None]:
# Global mean and (population, ddof=0) standard deviation of all values.
mean, std = arr.mean(), arr.std()
mean, std

#### Retrieve Precipitations (FAO WaPOR PCP - from CHIRPS catalog)

##### PRECIPITATIONS - not expanded

In [None]:
# Survey clusters, and the folder holding the monthly precipitation rasters.
dataset = pd.read_csv(
    '../data/dataset_2013+.csv'
)
PATH = os.path.join(
    '../data', 'additional_data', 'precipitation'
)

In [None]:
# `head()` is not the cell's last expression, so nothing is displayed here.
dataset.head()
# Save a copy of the survey table as the base CSV for the additional variables.
dataset.to_csv(path_or_buf='../data/dataset_additional.csv', index=False)

In [None]:
# Add an empty-string 'precipitation' column (overwritten if it already exists).
dataset["precipitation"] = ''

In [None]:
def correct_island_coordinates(x, y):
    '''Return the closest valid point for clusters on islands or coastlines.

    The precipitation rasters are too coarse to contain small islands and
    some coastlines. Clusters falling in the hand-picked 1-degree cells
    below are therefore moved to a nearby point that has data. Cells are
    identified by `int()` of the coordinates (truncation toward zero).

    Args
    - x: float, longitude
    - y: float, latitude

    Returns: (x, y) tuple, the corrected coordinates or the inputs unchanged
    '''
    ix, iy = int(x), int(y)
    # Tanzania (presumably Zanzibar / Pemba). Previously written as
    # `iy == -5 or iy == -6 and ix == 39`; since `and` binds tighter than
    # `or`, every point with int(y) == -5 was moved here.
    if iy in (-5, -6) and ix == 39:
        return 39.29, -5.98
    # Sierra Leone
    if ix in (-13, -12) and iy in (8, 7):
        return -12.7, 7.8
    # Senegal
    if ix in (-12, -13, -16, -17) and iy in (12, 13, 14, 15, 16):
        return x - 1, y
    # Mozambique
    if (ix, iy) in ((32, -25), (40, -12)):
        return ix, iy
    if (ix, iy) == (34, -19):
        return 34.80, -19.80
    # Madagascar
    if (ix, iy) == (43, -23):
        return x + 1, y
    if ix in (48, 49) and iy in (-12, -13):
        return x, y - 2
    # Guinea
    if (ix, iy) == (-13, 9):
        return x + 0.5, y + 0.5
    # Cote d'Ivoire
    if (ix, iy) == (-6, 4):
        return x + 0.5, y + 0.5
    # Benin
    if ix in (1, 2) and iy == 6:
        return x, y + 0.5
    # Angola
    if ix == -13 and iy in (-8, -12):
        return x + 0.5, y
    return x, y

In [None]:
# Creates the precipitation pickle dictionary with manually corrected island coordinates.
# Result: precipitations[(country, year, prev_year, cluster)] is a list with
# one value sampled at the cluster location per monthly raster, appended in
# sorted file-name order.

precipitations = dict()
dataset = pd.read_csv('../data/dataset_viirs_only.csv')

for year in tqdm(dataset.year.unique()):
    df = dataset[dataset.year == year]
    print(year)
    for prev_year in range(year - 4, year + 1):
        # `glob` is the function bound by `from glob import glob`, so
        # `glob.glob(...)` would raise AttributeError. Sorting makes the
        # monthly values follow file-name order instead of filesystem order.
        # NOTE(review): assumes month parts of the names are zero-padded.
        monthly_tifs = sorted(glob(os.path.join(PATH, str(prev_year) + "*.tif")))
        for tif in monthly_tifs:
            with rio.open(tif) as src:
                for country in df.country.unique():
                    df_country = df[df.country == country]
                    for cluster in df_country.cluster.unique():
                        key = (country, year, prev_year, cluster)
                        if precipitations.get(key) is None:
                            precipitations[key] = []
                        row = df_country.loc[df_country['cluster'] == cluster]
                        x = float(row.at[row.index[0], 'lon'])
                        y = float(row.at[row.index[0], 'lat'])
                        # The original rasters are too coarse to capture small islands and
                        # coastlines as valid coordinates: take the closest coastal point instead.
                        x, y = correct_island_coordinates(x, y)
                        vector = []
                        for val in src.sample([(x, y)]):
                            vector = val[0]  # first band of the sampled pixel
                        precipitations[key].append(vector)
                with open('../data/additional_data/precipitations.pickle', 'wb') as handle:
                    pickle.dump(precipitations, handle, protocol=pickle.HIGHEST_PROTOCOL)

##### PRECIPITATIONS - expanded

In [None]:
# Creates the precipitation pickle dictionary from the expanded precipitation
# rasters. Same method as the non-expanded version, but without the island
# coordinate correction.
# NOTE(review): PATH and the output pickle are the same as in the
# non-expanded cell, so running this overwrites that result — confirm intended.

precipitations = dict()
dataset = pd.read_csv('../data/dataset_2013+.csv')
PATH = os.path.join('../data', 'additional_data', 'precipitation')

for year in tqdm(dataset.year.unique()):
    df = dataset[dataset.year == year]
    print(year)
    for prev_year in range(year - 4, year + 1):
        # `glob` is the function bound by `from glob import glob`, so
        # `glob.glob(...)` would raise AttributeError. Sorting makes the
        # monthly values follow file-name order instead of filesystem order.
        monthly_tifs = sorted(glob(os.path.join(PATH, str(prev_year) + "*.tif")))
        for tif in monthly_tifs:
            with rio.open(tif) as src:
                for country in df.country.unique():
                    df_country = df[df.country == country]
                    for cluster in df_country.cluster.unique():
                        key = (country, year, prev_year, cluster)
                        if precipitations.get(key) is None:
                            precipitations[key] = []
                        row = df_country.loc[df_country['cluster'] == cluster]
                        # Sampled at the raw cluster coordinates (no island correction).
                        x = float(row.at[row.index[0], 'lon'])
                        y = float(row.at[row.index[0], 'lat'])
                        vector = []
                        for val in src.sample([(x, y)]):
                            vector = val[0]  # first band of the sampled pixel
                        precipitations[key].append(vector)
                with open('../data/additional_data/precipitations.pickle', 'wb') as handle:
                    pickle.dump(precipitations, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load a saved precipitation dictionary and display it.
with open('../data/additional_data/precipitation_1.pickle', 'rb') as pickle_in:
    dic = pickle.load(pickle_in)
dic

#### Add new variable from raster

In [None]:
# Inputs for sampling the yearly conflict rasters, and the output pickle path.
dataset = pd.read_csv('../data/dataset_2013+.csv')
DATA_DIR = '../data/additional_data/conflicts/'
DICT_NAME = '../data/additional_data/conflict.pickle'

In [None]:
# We subdivide the dataset rather than iterating through rows naively to
# limit the number of raster-reading operations, which is the bottleneck of
# this preprocessing step. Each yearly raster is opened once and sampled for
# every cluster of the surveys that need it.
# Result: series_dict[(country, year, prev_year, cluster)] is a list with one
# sampled value appended per prev_year in [year-4, year].
series_dict = dict()
for year in tqdm(dataset.year.unique()):
    df = dataset[dataset.year == year]
    print(year)
    for prev_year in range(year - 4, year + 1):
        # `glob` is the function bound by `from glob import glob`, so
        # `glob.glob(...)` would raise AttributeError. Sorting makes the
        # choice deterministic if several files match.
        tifs = sorted(glob(os.path.join(DATA_DIR, str(prev_year) + "*.tif")))
        if not tifs:
            raise FileNotFoundError(f"No raster matching '{prev_year}*.tif' in {DATA_DIR}")
        # We leave the tif file open and sample all observations at once.
        with rio.open(tifs[0]) as src:
            for country in df.country.unique():
                df_country = df[df.country == country]
                for cluster in df_country.cluster.unique():
                    key = (country, year, prev_year, cluster)
                    if series_dict.get(key) is None:
                        series_dict[key] = []
                    row = df_country.loc[df_country['cluster'] == cluster]
                    x = float(row.at[row.index[0], 'lon'])
                    y = float(row.at[row.index[0], 'lat'])
                    # Only the first band of the sampled pixel is kept.
                    # NOTE(review): confirm the conflict rasters are single-band.
                    vector = []
                    for val in src.sample([(x, y)]):
                        vector = val[0]
                    series_dict[key].append(vector)
            with open(DICT_NAME, 'wb') as handle:
                pickle.dump(series_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Replace negative values with 0 in the conflict pickle dict.

DICT_NAME = '../data/additional_data/conflict.pickle'

# Load the raw conflict dictionary produced by the sampling step.
with open(DICT_NAME, 'rb') as source_file:
    loaded_dict = pickle.load(source_file)

def replace_neg(value):
    '''Clamp negative numbers to 0.

    A list is clamped element-wise into a new list; any other value is
    returned as `max(0, value)`.
    '''
    if not isinstance(value, list):
        return max(0, value)
    return [max(0, item) for item in value]


def replace_recursive(d):
    '''Apply `replace_neg` to every non-dict value of `d`, recursing into nested dicts.

    `d` is modified in place and also returned.
    '''
    for key, value in d.items():
        d[key] = replace_recursive(value) if isinstance(value, dict) else replace_neg(value)
    return d

# Replace negative values with 0. The dict is modified in place, so
# `conflict_noneg` and `loaded_dict` are the same object.
conflict_noneg = replace_recursive(loaded_dict)

# Save the clamped dict under a new name.
output_PATH = "../data/additional_data/conflict_noneg.pickle"
with open(output_PATH, 'wb') as target_file:
    pickle.dump(conflict_noneg, target_file)

print("Neg values replaced with 0 in new pickle dict")


#### Save `mean` & `std` to normalizer 

In [63]:
# For precipitations: compute (mean, std) of the series and store it in the
# shared normalizer.

DICT_NAME = '../data/additional_data/precipitation_1.pickle'
VARIABLE = 'precipitation'

from helper import get_mean_std_from_dict

with open(DICT_NAME, 'rb') as series_file:
    series_dict = pickle.load(series_file)
mean, std = get_mean_std_from_dict(series_dict)

normalizer_path = '../datasets/normalizer.pkl'
with open(normalizer_path, 'rb') as norm_in:
    normalizer = pickle.load(norm_in)
normalizer[VARIABLE] = mean, std
with open(normalizer_path, 'wb') as norm_out:
    pickle.dump(normalizer, norm_out)

min:  0.0 max:  17320.0


In [None]:
# For conflict: compute (mean, std) of the negative-clamped series and store
# it in the shared normalizer.

DICT_NAME = '../data/additional_data/conflict_noneg.pickle'
VARIABLE = 'conflict'

from helper import get_mean_std_from_dict

with open(DICT_NAME, 'rb') as series_file:
    series_dict = pickle.load(series_file)
mean, std = get_mean_std_from_dict(series_dict)

normalizer_path = '../datasets/normalizer.pkl'
with open(normalizer_path, 'rb') as norm_in:
    normalizer = pickle.load(norm_in)
normalizer[VARIABLE] = mean, std
with open(normalizer_path, 'wb') as norm_out:
    pickle.dump(normalizer, norm_out)

In [7]:
# Display the current normalizer contents.
with open('../datasets/normalizer.pkl', 'rb') as norm_in:
    normalizer = pickle.load(norm_in)
normalizer

{'landsat_+_nightlights': (array([5.17201265e-02, 8.82524713e-02, 1.02674783e-01, 2.64303597e-01,
         2.49776030e-01, 1.71501024e-01, 3.00020234e+02, 1.59417214e+00]),
  array([6.19937578e-04, 1.27377769e-03, 3.33014929e-03, 4.55055797e-03,
         9.47793062e-03, 8.13588032e-03, 2.26191530e+01, 1.01391193e+00])),
 'temperature': (299.78521219889325, 2.5680043231102125),
 'conflict': (2.8288902449655655, 31.847525639418162),
 'precipitation': (1029.2007072616948, 1219.951013718319)}

##### Remove 'random' from Normalizer

In [None]:
# Load the normalizer that contains the extra 'random' entry, drop that
# entry, and save the result as the main normalizer.
with open('../datasets/normalizer_random.pkl', 'rb') as norm_in:
    normalizer = pickle.load(norm_in)
key_to_remove = ['random']
for key in key_to_remove:
    normalizer.pop(key, None)  # no-op when the key is absent

# Save the modified dictionary
with open('../datasets/normalizer.pkl', 'wb') as norm_out:
    pickle.dump(normalizer, norm_out)



##### Why do we have negative values in the conflict dict?

In [None]:
# Number of nodata (NA) pixels in a conflict raster.

import rasterio

Raster_PATH = "../data/additional_data/conflicts/2019.tif"

with rasterio.open(Raster_PATH) as src:
    raster_data = src.read(1)
    nodata = src.nodata

    if nodata is None:
        # No nodata value is declared, so no pixel is flagged as missing.
        na_count = 0
    elif np.isnan(nodata):
        # NaN never compares equal to itself, so `== nodata` would count 0.
        na_count = int(np.isnan(raster_data).sum())
    else:
        na_count = int((raster_data == nodata).sum())

    print(na_count)

In [None]:
# Number of pixels < 0 in the first band of a conflict raster.

Raster_PATH = "../data/additional_data/conflicts/2018.tif"

with rasterio.open(Raster_PATH) as src:
    raster_data = src.read(1)
    neg_count = np.count_nonzero(raster_data < 0)
    print(neg_count)

In [None]:
# Print every entry of the raw conflict dict.

DICT_NAME = '../data/additional_data/conflict.pickle'
with open(DICT_NAME, 'rb') as pickle_file:
    loaded_dict = pickle.load(pickle_file)

for key in loaded_dict:
    print(f"Key: {key}, Value: {loaded_dict[key]}")

In [None]:
# Print every negative entry of the raw conflict PICKLE.

DICT_NAME = '../data/additional_data/conflict.pickle'
with open(DICT_NAME, 'rb') as pickle_file:
    loaded_dict = pickle.load(pickle_file)

for key, value in loaded_dict.items():
    if not isinstance(value, list):
        if value < 0:
            print(f"Key: {key}, Value: {value}")
        continue
    for element in value:
        if element < 0:
            print(f"Key: {key}, Element: {element}")

In [None]:
# Mean and STD of the negative entries in the conflict PICKLE.

# Same file as the other conflict cells (it was previously `conflict.pkl`,
# which none of them write).
DICT_NAME = '../data/additional_data/conflict.pickle'
with open(DICT_NAME, 'rb') as pickle_file:
    loaded_dict = pickle.load(pickle_file)

# Collect every negative entry, rounded to 3 decimals. List values are
# scanned element-wise; any other value is treated as a single entry.
negative_elements = []
for key, value in loaded_dict.items():
    values = value if isinstance(value, list) else [value]
    negative_elements.extend(round(v, 3) for v in values if v < 0)

# Calculate mean and std of the negative entries
if negative_elements:
    mean_neg = np.mean(negative_elements)
    std_neg = np.std(negative_elements)
    print(f"Mean of neg elements: {mean_neg:.3f}")
    print(f"STD of neg elements: {std_neg:.3f}")
else:
    print("No neg elem")


#### Determine best epoch for each fold

In [4]:
# Print the structure of a (nested) dict.

def print_dict_str(d, indent=0):
    '''Print one line per key, using `indent` leading spaces.

    Nested dicts are shown as "key: (dict)" and recursed into with 4 more
    spaces of indentation. Other values are shown as "key: <type name>".
    '''
    pad = " " * indent
    for key, value in d.items():
        if not isinstance(value, dict):
            print(f"{pad}{key}: {type(value).__name__}")
            continue
        print(f"{pad}{key}: (dict)")
        print_dict_str(value, indent + 4)


In [5]:
# Show the layout of one results file.
with open('../models/results/conf_t/msnlt_conf_t2.pkl', 'rb') as results_file:
    dic = pickle.load(results_file)

print_dict_str(dic)


A: (dict)
    train_loss: list
    train_r2: list
    test_loss: list
    test_r2: list
B: (dict)
    train_loss: list
    train_r2: list
    test_loss: list
    test_r2: list
C: (dict)
    train_loss: list
    train_r2: list
    test_loss: list
    test_r2: list
D: (dict)
    train_loss: list
    train_r2: list
    test_loss: list
    test_r2: list
E: (dict)
    train_loss: list
    train_r2: list
    test_loss: list
    test_r2: list


In [8]:
# Best epoch per fold for one results file.

with open('../models/results/conf_t/ts_conf_t3.pkl', 'rb') as results_file:
    dic = pickle.load(results_file)

# np.argmin / np.argmax return 0-based positions; +1 reports them as epoch
# numbers counted from 1 (assumes one list entry per epoch).
best_epochs = {
    fold: {
        'best_test_loss_epoch': np.argmin(fold_results['test_loss']) + 1,
        'best_test_r2_epoch': np.argmax(fold_results['test_r2']) + 1,
    }
    for fold, fold_results in dic.items()
}

# Only the test_r2-based selection is reported.
print("Best epochs based on test_r2:")
for fold, epochs in best_epochs.items():
    print(f"Fold {fold}: Epoch {epochs['best_test_r2_epoch']} (Test R-squared)")


Best epochs based on test_r2:
Fold A: Epoch 60 (Test R-squared)
Fold B: Epoch 60 (Test R-squared)
Fold C: Epoch 55 (Test R-squared)
Fold D: Epoch 57 (Test R-squared)
Fold E: Epoch 59 (Test R-squared)


In [24]:
# Name of the experiment whose results and checkpoints are processed below.
test_name = "ts_conf_t100"

In [23]:
# Build a dict with the best epoch (by test_r2) of each fold, and print it.

with open('../models/results/conf_t/' + test_name + '.pkl', 'rb') as results_file:
    dic = pickle.load(results_file)

# np.argmax is 0-based; +1 gives the epoch number counted from 1.
best_epochs = {
    fold: {'best_test_r2_epoch': np.argmax(fold_results['test_r2']) + 1}
    for fold, fold_results in dic.items()
}

# Copy restricted to, and ordered by, folds A-E; raises KeyError if one of
# them is missing from the results.
model_names = ['A', 'B', 'C', 'D', 'E']
best_epochs_dict = {
    name: {'best_test_r2_epoch': best_epochs[name]['best_test_r2_epoch']}
    for name in model_names
}

print("Best epochs based on test_r2:")
for model_name, epochs in best_epochs_dict.items():
    print(f"Fold {model_name}: Epoch {epochs['best_test_r2_epoch']} (Test R-squared)")


Best epochs based on test_r2:
Fold A: Epoch 60 (Test R-squared)
Fold B: Epoch 60 (Test R-squared)
Fold C: Epoch 55 (Test R-squared)
Fold D: Epoch 57 (Test R-squared)
Fold E: Epoch 59 (Test R-squared)


In [None]:
# Copy the best-epoch checkpoint files into a dedicated folder.

# Folder containing all checkpoints
all_checkpoints_folder = '../models/checkpoints/conf_t/all'

# Folder the best checkpoints are copied to (previously the name had a
# trailing space: '..._best '). Created if missing, since shutil.copy does
# not create directories.
best_checkpoints_folder = '../models/checkpoints/conf_t/' + test_name + '_best'
os.makedirs(best_checkpoints_folder, exist_ok=True)

model_names = ['A', 'B', 'C', 'D', 'E']
for fold, fold_best in best_epochs.items():
    # Each value of `best_epochs` is a dict such as {'best_test_r2_epoch': 60}.
    # The file name needs the epoch number itself, not the dict's repr.
    best_epoch = fold_best['best_test_r2_epoch']

    # NOTE(review): this copies one file per (model_name, fold) pair, i.e.
    # 25 files in total. Confirm the checkpoint naming really contains both;
    # otherwise only model_name == fold is needed.
    for model_name in model_names:
        checkpoint_filename = f"{test_name}_{model_name}_{fold}_{best_epoch}.pth"
        source_path = os.path.join(all_checkpoints_folder, checkpoint_filename)
        destination_path = os.path.join(best_checkpoints_folder, checkpoint_filename)
        shutil.copy(source_path, destination_path)

# Inform the user that the operation is completed
print("Best checkpoints copied to 'best' folder.")

