In [2]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.functional as nn

from tqdm import tqdm
from torch_geometric.data import Data
import sys
sys.path.append('../src')
from ice_graph.ice_graph import Ice_graph


## Notebook for creating gifs, images, maps...

In [3]:
files = sorted(os.listdir('../../week_data'))
file_graphs = []
for file in files:
    try:
        file_graphs.append(dict(np.load('../../week_data/' + file)))
    except:
        print(file)
len(files),len(file_graphs)

field_20230102T170000Z.npz


(336, 335)

In [4]:

#file_graphs = [dict(np.load(f'../../week_data/{file}')) for file in sorted(os.listdir('../../week_data')) if file[-3:]=='npz']
nextsim = Ice_graph(
    file_graphs,
    vertex_element_features =
        ['M_wind_x',
        'M_wind_y',
        'M_ocean_x',
        'M_ocean_y',
        'M_VT_x',
        'M_VT_y',
        'x',
        'y']
)

n_generations = 30000
predict_vel = True

radius = 400000 #meters
iterations = 1
time_index = 1 #index of element graph to fetch samples from
time_index_val = 12

validation_center = (0,0)
training_center = (0,0)
samples_train = nextsim.get_samples_area(training_center,radius,time_index=time_index,n_samples=n_generations,elements=False)
samples_val = nextsim.get_samples_area(validation_center,radius,time_index=time_index_val,n_samples=int(n_generations/5),elements=False)




In [71]:
#This cell plots the spatial distribution of node indexes incrementally. It helped me to understand the indexes

if not os.path.isdir('../figures/indexes'):
    os.mkdir('../figures/indexes')

vel_norm = np.sqrt(d0['M_VT_x']**2 + d0['M_VT_y']**2)
step = 1000
lenght = vel_norm.shape[0]


for i in range(0,lenght,step):
    fig,ax = plt.subplots(1,1,figsize=(10,10))
    ax.scatter(d0['x'][:i],d0['y'][:i], c=vel_norm[:i], s= 3 ,marker='.',linewidths=.7)
    ax.set_title(f'{i}/{lenght} points')
    plt.savefig(f'../figures/indexes/points_{i}.png')
    plt.close(fig)

In [46]:
#Simple "movie" to see that the ice actually moves :)

if not os.path.isdir('../figures/velocities'):
    os.mkdir('../figures/velocities')

for i,file in enumerate(file_graphs):
    # plot sea ice concentration (for each ELEMENT)
    vel_norm = np.sqrt(file['M_VT_x']**2 + file['M_VT_y']**2)

    fig = plt.figure(figsize=(10,10))
    plt.tripcolor(file['x'], file['y'], vel_norm, triangles=file['t'])
    plt.colorbar()
    plt.gca().set_aspect('equal')
    plt.savefig(f'../figures/velocities/snapshot_{i}.png')
    plt.close(fig)

In [16]:
np.isin(['a', 'b'],  ['c','d','a']).all()

False

In [5]:
interp.keys()

dict_keys(['M_wind_x', 'M_wind_y', 'M_ocean_x', 'M_ocean_y'])

In [5]:

full_graph = nextsim.get_vertex_centered_graph(
        samples_train[0],
        time_index = time_index,
        target_iter=iterations,
        include_vertex=True,
        velocity=predict_vel,
        n_neighbours=1
    )

center = full_graph[0].y[1]


for time_index in tqdm(range(1,len(file_graphs))):

    fet = [ 'Concentration', 'Thickness', 'x', 'y']

    full_graph = nextsim.get_vertex_centered_graph(
        samples_train[0],
        time_index = time_index,
        target_iter=iterations,
        e_features=fet,
        include_vertex=True,
        velocity=True,
        n_neighbours=20
    )


    e_g,v_g = full_graph

    plt.figure(figsize=(15,10))
    central_pos = v_g.y[1]


    for edge in v_g.edge_index.t():
        positions = np.array([v_g.pos.t()[edge[0]],v_g.pos.t()[edge[1]]])
        positions = positions.transpose()
        plt.plot(positions[0],positions[1],color='red',linewidth=.9)

    """

    for edge in e_g.edge_index.t():
        positions = np.array([e_g.pos.t()[edge[0]],e_g.pos.t()[edge[1]]])
        positions = positions.transpose()
        plt.plot(positions[0],positions[1],color='blue',linewidth=.5,alpha=.5)
    """
    #plt.scatter(e_g.pos[0],e_g.pos[1],color='blue',label='Element',alpha=.5, s= 10)
    plt.scatter(v_g.pos[0],v_g.pos[1],color='red',label='Vertex', s= 4, alpha=.5)

    plt.scatter(central_pos[0],central_pos[1],marker='x', s=100, linewidth=5, c='red', label='Central vertex')


    interp = nextsim.get_forcings(time_index, features=['M_VT_x','M_VT_y'])
    X = np.linspace(center[0]-radius,center[0]+radius,num=1000)
    Y = np.linspace(center[1]-radius,center[1]+radius,num=1000)
    X, Y = np.meshgrid(X, Y)  # 2D grid for interpolation
    Z_x = interp['M_VT_x'](X,Y)
    Z_y = interp['M_VT_y'](X,Y)
    norm = np.sqrt(Z_x**2 + Z_y**2)
    plt.contourf(X,Y,norm,levels=30,cmap='viridis',alpha=.7,label='Sea ice velocity (m/s)')

    plt.xlim(center[0]-radius/2,center[0]+radius/2)
    plt.ylim(center[1]-radius/2,center[1]+radius/2)

    plt.legend()
    plt.title(f"Graph at time index {time_index}")
    plt.colorbar(label='Sea ice velocity (m/s)')
    plt.savefig(f'../figures/graph_evolution/graph_{time_index}.png')

  plt.contourf(X,Y,norm,levels=30,cmap='viridis',alpha=.7,label='Sea ice velocity (m/s)')
  plt.figure(figsize=(15,10))
 15%|█▍        | 49/334 [10:08<1:15:14, 15.84s/it]