# Загружаем архив и датасет:

In [None]:
# Mount Google Drive so the archive stored there becomes accessible.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Check that the archive is present on Drive.
!ls /content/drive/MyDrive/GNN/

In [None]:
# Unpack the archive into the current working directory (later cells expect
# it at /content/deepmind-research-master).
!unzip -q /content/drive/MyDrive/GNN/deepmind-research-master.zip

In [None]:
# Create directories for the dataset download.
# NOTE(review): "/tmp/rollous" looks like a typo for "rollouts"; nothing visible
# in this notebook uses it — confirm.
!mkdir -p /tmp/rollous
!mkdir -p /tmp/datasets

In [None]:
# Download the WaterRamps dataset into /tmp/datasets with the repository's script.
!bash /content/deepmind-research-master/learning_to_simulate/download_dataset.sh WaterRamps /tmp/datasets

## Извлекаем данные из TFRecord-файлов

In [None]:
# Change into the unpacked repository so that `learning_to_simulate` can be imported.
%cd /content/deepmind-research-master

In [None]:
import functools
import os
import json
import pickle

import tensorflow.compat.v1 as tf
import numpy as np

from learning_to_simulate import reading_utils

In [None]:
# Dataset directory and the file name of the training split.
data_path, filename = '/tmp/datasets/WaterRamps', 'train.tfrecord'

### Metadata

In [None]:
# Читаем metadata
def _read_metadata(data_path):
    with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
        return json.loads(fp.read())

metadata = _read_metadata(data_path=data_path)  # parsed metadata.json of the dataset

In [None]:
# Print every metadata entry as "key value".
for name, value in metadata.items():
  print(name, value)

In [None]:
# Global constants taken from the metadata.
# NOTE: CON_RAD holds the *squared* default connectivity radius, despite its name.
CON_RAD, SEQ_LEN = metadata['default_connectivity_radius'] ** 2, metadata['sequence_length']

### Data

В списке positions находятся траектории частиц: для каждой траектории positions[k].shape = [t_steps_num, nodes_num, dim].

In [None]:
# Raw TFRecord stream of the training split, parsed into (context, features) pairs.
ds_org = tf.data.TFRecordDataset(filenames=[os.path.join(data_path, filename)])
ds = ds_org.map(
    functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata)
)

In [None]:
# Unpack every training trajectory: particle types and key come from the
# context; per-step particle positions come from the sequence features.
particle_types, keys, positions = [], [], []
for context, features in ds:
    particle_types.append(context["particle_type"].numpy().astype(np.int64))
    keys.append(context["key"].numpy().astype(np.int64))
    positions.append(features["position"].numpy().astype(np.float32))

In [None]:
# Collect every distinct particle-type value that occurs in the training set.
# set().union(*...) visits each trajectory once and yields an empty set (rather
# than raising IndexError) when there are no trajectories.
res_set = set().union(*particle_types)
print('Different values in patrical_types:')
print(res_set)

In [None]:
# Replace each particle-type array, in place, by a 0/1 mask: 1 where the
# original type is 5 and 0 for every other type (the int64 dtype is kept).
# Later cells use this mask to select the particles the model predicts.
# NOTE(review): presumably type 5 marks the water particles — confirm.
# NOTE: not idempotent — running this cell twice turns every entry into 0.
for types in particle_types:
  types[:] = types == 5

In [None]:
# Shape of the first trajectory's positions array.
print('Shape of each element in positions:', positions[0].shape, sep='\n')

In [None]:
def get_borders_features(positions, borders):
  """Return signed offsets of every point from the four domain walls.

  Args:
    positions: array of shape [N, >=2]; columns 0 and 1 are the x and y coordinates.
    borders: ((x_min, x_max), (y_min, y_max)), e.g. metadata['bounds'].

  Returns:
    Array of shape [N, 4] with columns x - x_min, x_max - x, y - y_min, y_max - y.
  """
  xs, ys = positions[:, 0], positions[:, 1]
  columns = [xs - borders[0][0], borders[0][1] - xs, ys - borders[1][0], borders[1][1] - ys]
  return np.stack(columns, axis=1)

### Валидационная выборка

In [None]:
# Dataset directory and the file name of the validation split.
val_data_path, val_filename = '/tmp/datasets/WaterRamps', 'valid.tfrecord'

In [None]:
# Raw TFRecord stream of the validation split, parsed into (context, features) pairs.
ds_org = tf.data.TFRecordDataset(filenames=[os.path.join(val_data_path, val_filename)])
ds = ds_org.map(
    functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata)
)

In [None]:
# Unpack every validation trajectory (same layout as the training split).
val_particle_types, val_keys, val_positions = [], [], []
for context, features in ds:
    val_particle_types.append(context["particle_type"].numpy().astype(np.int64))
    val_keys.append(context["key"].numpy().astype(np.int64))
    val_positions.append(features["position"].numpy().astype(np.float32))

In [None]:
# Replace each validation particle-type array, in place, by a 0/1 mask: 1 where
# the original type is 5 and 0 otherwise (the int64 dtype is kept).
# NOTE(review): presumably type 5 marks the water particles — confirm.
# NOTE: not idempotent — running this cell twice turns every entry into 0.
for types in val_particle_types:
  types[:] = types == 5

# GNN

## Устанавливаем torch_geometric, импортируем библиотеки

In [None]:
# Install PyTorch Geometric and its compiled extension packages.
# NOTE(review): the wheel index is pinned to torch 1.10.0 + CUDA 11.3; it has to
# match the torch/CUDA build installed in the runtime — confirm before running.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install torch-geometric

In [None]:
import torch
from torch import Tensor
from torch_geometric.nn import GCNConv, MessagePassing, EdgeConv
from torch_cluster import knn_graph
import torch.nn as nn
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm.notebook import trange
from tqdm import tqdm
from torch.optim import lr_scheduler

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

## Message passing net

### CMPNV3

In [None]:
class CMPNV3(MessagePassing):
    """Graph network that predicts a 2-D output for the particles selected by a mask.

    ``forward`` performs three steps:
      1. It builds a k-nearest-neighbour graph (self-loops included) on the
         first two feature columns.
      2. It refines the node features with three residual ``EdgeConv`` layers.
      3. It runs one more message-passing round (``message``, then sum
         aggregation, then ``update``). The readout MLP of this round writes
         the first two output columns of every node selected by ``nodes_mask``.

    Args:
        in_channels: width of each node's input feature vector.
        hidden_channels: base hidden width of the internal MLPs.
        k: number of neighbours per node in the kNN graph.
    """

    def __init__(self, in_channels, hidden_channels, k=10):
        super().__init__(aggr='add')

        # Sub-module names, layer sizes and creation order match the original
        # hand-written nn.Sequential definitions. This keeps existing
        # state_dict checkpoints loadable and weight initialisation unchanged.
        self.gnn1_mlp = self._mlp(2*in_channels, hidden_channels//4, hidden_channels//2,
                                  hidden_channels, in_channels)
        self.gnn1 = EdgeConv(nn=self.gnn1_mlp, aggr='add')

        self.gnn2_mlp = self._mlp(2*in_channels, 2*hidden_channels, 2*hidden_channels,
                                  hidden_channels, in_channels)
        self.gnn2 = EdgeConv(nn=self.gnn2_mlp, aggr='add')

        self.gnn3_mlp = self._mlp(2*in_channels, 2*hidden_channels, 2*hidden_channels,
                                  hidden_channels, in_channels)
        self.gnn3 = EdgeConv(nn=self.gnn3_mlp, aggr='add')

        # Message MLP, applied per edge to [x_i, x_i - x_j].
        self.mlp = self._mlp(2*in_channels, 2*hidden_channels, hidden_channels, hidden_channels)

        # Readout MLP: [input features, aggregated messages] -> 2 values per node.
        self.lin = self._mlp(in_channels + hidden_channels, hidden_channels,
                             hidden_channels//2, hidden_channels//4, 2)

        self.k = k

    @staticmethod
    def _mlp(*sizes):
        """Build Linear/BatchNorm1d/ReLU stages between consecutive ``sizes``, ending in a plain Linear."""
        layers = []
        for d_in, d_out in zip(sizes[:-2], sizes[1:-1]):
            layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU()]
        layers.append(nn.Linear(sizes[-2], sizes[-1]))
        return nn.Sequential(*layers)

    def forward(self, x, nodes_mask):
        """Predict the first two output columns for the masked nodes.

        Args:
            x: node features of shape [N, in_channels]; columns 0-1 are used as
                coordinates for the kNN graph.
            nodes_mask: boolean mask of shape [N] selecting the nodes to predict.

        Returns:
            Tensor of shape [N, 2]. Rows selected by ``nodes_mask`` hold the
            prediction; the other rows are copies of ``x[:, :2]``. ``x`` itself
            is not modified.
        """
        # edge_index has shape [2, E]
        edge_index = knn_graph(x[:, :2], self.k, loop=True, flow=self.flow)

        # Residual EdgeConv refinement. Out-of-place additions avoid mutating
        # tensors that the EdgeConv layers have just consumed.
        new_embedding = x + self.gnn1(x, edge_index)
        new_embedding = new_embedding + self.gnn2(new_embedding, edge_index)
        new_embedding = new_embedding + self.gnn3(new_embedding, edge_index)

        return self.propagate(edge_index=edge_index, x=new_embedding, nodes_mask=nodes_mask, first_input=x)

    def message(self, x_i, x_j, nodes_mask):
        # x_i, x_j have shape [E, in_channels]; nodes_mask is not used here.
        return self.mlp(torch.cat([x_i, x_i - x_j], dim=1))

    def update(self, aggr_out, nodes_mask, first_input):
        # aggr_out has shape [N, hidden_channels] (sum of the message MLP outputs).
        new_embedding = torch.cat([first_input[nodes_mask], aggr_out[nodes_mask]], dim=1)
        new_embedding = self.lin(new_embedding)

        # first_input[:, :2] is a view. Writing into it without clone() would
        # silently overwrite the caller's input tensor.
        result = first_input[:, :2].clone()
        result[nodes_mask] = new_embedding

        return result

In [None]:
# Source text of a 5-EdgeConv-layer variant of CMPNV3, kept only as a string so it
# can be written to net_info.txt (see the commented-out cell further below).
# The string is never executed.
net_arch = '''
class CMPNV3(MessagePassing):
    def __init__(self, in_channels, hidden_channels, k=10):
        super().__init__(aggr='add')

        self.gnn1_mlp = nn.Sequential(
                       nn.Linear(2*in_channels, hidden_channels//4),
                       nn.BatchNorm1d(hidden_channels//4),
                       nn.ReLU(),
                       nn.Linear(hidden_channels//4, hidden_channels//2),
                       nn.BatchNorm1d(hidden_channels//2),
                       nn.ReLU(),
                       nn.Linear(hidden_channels//2, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, in_channels)
                       )

        self.gnn1 = EdgeConv(nn=self.gnn1_mlp, aggr='add')

        self.gnn2_mlp = nn.Sequential(
                       nn.Linear(2*in_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, in_channels)
                       )
        self.gnn2 = EdgeConv(nn=self.gnn2_mlp, aggr='add')

        self.gnn3_mlp = nn.Sequential(
                       nn.Linear(2*in_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, in_channels)
                       )
        self.gnn3 = EdgeConv(nn=self.gnn3_mlp, aggr='add')

        self.gnn4_mlp = nn.Sequential(
                       nn.Linear(2*in_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, in_channels)
                       )
        self.gnn4 = EdgeConv(nn=self.gnn4_mlp, aggr='add')

        self.gnn5_mlp = nn.Sequential(
                       nn.Linear(2*in_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, in_channels)
                       )
        
        self.gnn5 = EdgeConv(nn=self.gnn5_mlp, aggr='add')

        self.mlp = nn.Sequential(
                       nn.Linear(2*in_channels, 2*hidden_channels),
                       nn.BatchNorm1d(2*hidden_channels),
                       nn.ReLU(),
                       nn.Linear(2*hidden_channels, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, hidden_channels)
                       )
        
        #self.conv = GCNConv(in_channels, conv_channels)

        self.lin = nn.Sequential(
                       nn.Linear(in_channels+hidden_channels, hidden_channels),
                       nn.BatchNorm1d(hidden_channels),
                       nn.ReLU(),
                       nn.Linear(hidden_channels, hidden_channels//2),
                       nn.BatchNorm1d(hidden_channels//2),
                       nn.ReLU(),
                       nn.Linear(hidden_channels//2, hidden_channels//4),
                       nn.BatchNorm1d(hidden_channels//4),
                       nn.ReLU(),
                       nn.Linear(hidden_channels//4, 2)
                       )
        
        self.k = k

    def forward(self, x, nodes_mask):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        edge_index = knn_graph(x[:, :2], self.k, loop=True, flow=self.flow)
      
        new_embedding = x + self.gnn1(x, edge_index)
        new_embedding += self.gnn2(new_embedding, edge_index)
        new_embedding += self.gnn3(new_embedding, edge_index)
        new_embedding += self.gnn4(new_embedding, edge_index)
        new_embedding += self.gnn5(new_embedding, edge_index)

        
        return self.propagate(edge_index=edge_index, x=new_embedding, nodes_mask=nodes_mask, first_input=x)

    def message(self, x_i, x_j, nodes_mask):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        return self.mlp(torch.cat([x_i, x_i-x_j], dim=1))
      
    def update(self, aggr_out, nodes_mask, first_input):
        # aggr_out has shape [N, latent_size]

        new_embedding = torch.cat([first_input[nodes_mask], aggr_out[nodes_mask]], dim=1)
        new_embedding = self.lin(new_embedding)

        result = first_input[:, :2]
        result[nodes_mask] = new_embedding
        
        return result
'''

## Train function

### Train per step

In [None]:
def get_borders_features(positions, borders):
  """Return signed offsets of every point from the four domain walls.

  Args:
    positions: array of shape [N, >=2]; columns 0 and 1 are the x and y coordinates.
    borders: ((x_min, x_max), (y_min, y_max)), e.g. metadata['bounds'].

  Returns:
    Array of shape [N, 4] with columns x - x_min, x_max - x, y - y_min, y_max - y.
  """
  xs, ys = positions[:, 0], positions[:, 1]
  columns = [xs - borders[0][0], borders[0][1] - xs, ys - borders[1][0], borders[1][1] - ys]
  return np.stack(columns, axis=1)

In [None]:
def build_scene(data, particle_types, inds, metadata, input_len, noise_scale):
  """Build the model input and target for one time step of one trajectory.

  Args:
    data: sequence of trajectories; ``data[i][j]`` is the [nodes_num, dim]
      positions array of trajectory ``i`` at step ``j`` (get_borders_features
      assumes dim == 2).
    particle_types: per-trajectory 0/1 arrays; non-zero entries mark the nodes
      the model predicts.
    inds: ``(i, j)``, the trajectory index and the current time step.
    metadata: dataset metadata; only ``metadata['bounds']`` is read.
    input_len: number of past position differences (velocities) to include.
    noise_scale: standard deviation of the multiplicative Gaussian noise
      (mean 1.0) applied to the current positions.
      NOTE(review): multiplicative noise scales with coordinate magnitude — confirm intent.

  Returns:
    features: float tensor [nodes_num, dim + input_len*dim + 4 + 1]. Its columns
      are the noisy current position, the ``input_len`` most recent position
      differences (newest first), the 4 wall offsets and the particle type.
    node_mask: bool tensor [nodes_num].
    next_pos: tensor [nodes_num, dim] with the positions at step ``j + 1``.

  Raises:
    IndexError: if ``j < input_len`` (negative indices would otherwise silently
      wrap around to the end of the trajectory) or if step ``j + 1`` does not exist.
  """
  i, j = inds
  trajectory = data[i]
  if j < input_len or j + 1 >= len(trajectory):
    raise IndexError(
        f'step {j} needs {input_len} previous steps and one next step; '
        f'trajectory {i} has {len(trajectory)} steps')

  noise = np.random.normal(loc=1.0, scale=noise_scale, size=trajectory[j].shape)

  features = torch.tensor(trajectory[j]*noise, dtype=torch.float)

  # Finite-difference velocities, newest first: x_j - x_{j-1}, x_{j-1} - x_{j-2}, ...
  for l in range(input_len):
    features = torch.cat([features, torch.tensor(trajectory[j-l] - trajectory[j-l-1])], dim=1)

  features = torch.cat([features, torch.tensor(get_borders_features(trajectory[j], metadata['bounds'])),
                        torch.tensor(particle_types[i]).view(-1, 1)], dim=1)

  # Use torch.tensor(...).bool() instead of the legacy torch.BoolTensor constructor.
  node_mask = torch.tensor(particle_types[i]).bool()
  next_pos = torch.tensor(trajectory[j+1])

  return features, node_mask, next_pos

In [None]:
from torch.utils.data import DataLoader
import random
import time

def _append_line(path, value):
  """Append ``value`` (formatted as ``print`` would) as one line of the text file ``path``."""
  with open(path, "a") as file:
    print(value, file=file)


def _scene_loss(model, criterion, data, particle_types, inds, input_len, noise_scale):
  """Build scene ``inds`` of ``data``, run ``model`` on it, and return the loss over masked nodes."""
  features, node_mask, next_pos = build_scene(data, particle_types, inds, metadata, input_len, noise_scale)

  features = features.to(device)
  node_mask = node_mask.to(device)
  next_pos = next_pos.to(device)

  y_pred = model(features, node_mask)
  return criterion(y_pred[node_mask], next_pos[node_mask])


def _scene_indices(data, input_len):
  """Return every (trajectory, step) pair that can serve as a scene.

  Step ``j`` needs ``input_len`` predecessors (for the velocity history) and a
  successor (the target). So ``j`` runs from ``input_len`` up to
  ``metadata['sequence_length'] - 1``, capped by the trajectory's own length.
  """
  return [(i, j)
          for i, trajectory in enumerate(data)
          for j in range(input_len, min(metadata['sequence_length'], len(trajectory) - 1))]


def train_vel(model, optimizer, scheduler, criterion, epochs, train_data, particle_types_train,
              val_data, particle_types_val, save_dir='/content/drive/MyDrive/GNN/tets_1',
              input_len=5, train_noise_scale=0.0003, val_noise_scale=0.0):
  """Train ``model`` on one-step position prediction and validate after every epoch.

  Each epoch does the following:
    * one pass over all training scenes in random order, with one optimizer
      step per scene;
    * one pass over all validation scenes without gradients;
    * ``scheduler.step()``.
  The mean losses and the epoch wall-clock time are appended to text files in
  ``save_dir``, and the model's state_dict is saved there after every epoch.

  Args:
    model, optimizer, scheduler, criterion: training objects.
    epochs: number of epochs to run.
    train_data, particle_types_train: training trajectories and their 0/1 masks.
    val_data, particle_types_val: validation trajectories and their 0/1 masks.
    save_dir: existing directory for ``state_dict_model.pt`` and the
      ``train_losses.txt`` / ``val_losses.txt`` / ``time.txt`` logs. If the
      checkpoint already exists, training resumes from it; otherwise training
      starts from the model's current weights.
    input_len: number of past velocity steps per scene. It must match the
      model's ``in_channels``.
    train_noise_scale: std of the multiplicative input noise for training scenes.
    val_noise_scale: the same for validation scenes. The default of 0.0 makes
      the validation loss measure clean-input error; pass 0.0003 to reproduce
      the earlier behaviour.

  Returns:
    (train_losses, val_losses): lists of per-epoch mean losses.

  Raises:
    FileNotFoundError: if ``save_dir`` does not exist (e.g. Drive is not mounted).
  """
  # Fail before any training work if the output directory is missing, instead
  # of failing after the first epoch when the logs are written.
  if not os.path.isdir(save_dir):
    raise FileNotFoundError(f'save_dir does not exist: {save_dir}')

  checkpoint_path = os.path.join(save_dir, 'state_dict_model.pt')
  if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

  train_inds = _scene_indices(train_data, input_len)
  val_inds = _scene_indices(val_data, input_len)

  train_losses = []
  val_losses = []

  for epoch in range(epochs):

    start_epoch = time.time()

    print('Epoch', epoch+1)

    model.train()
    seq_train_losses = []

    print('Train')

    random.shuffle(train_inds)
    for inds in tqdm(train_inds):
      optimizer.zero_grad()
      loss = _scene_loss(model, criterion, train_data, particle_types_train, inds,
                         input_len, train_noise_scale)
      seq_train_losses.append(loss.item())
      loss.backward()
      optimizer.step()

    train_losses.append(np.mean(seq_train_losses))
    print('Train loss %f' % train_losses[-1])
    _append_line(os.path.join(save_dir, "train_losses.txt"), train_losses[-1])

    # validation
    model.eval()
    seq_val_losses = []

    print('Validation')

    random.shuffle(val_inds)
    with torch.no_grad():
      for inds in tqdm(val_inds):
        loss = _scene_loss(model, criterion, val_data, particle_types_val, inds,
                           input_len, val_noise_scale)
        seq_val_losses.append(loss.item())

    scheduler.step()

    val_losses.append(np.mean(seq_val_losses))
    print('Validation loss %f' % val_losses[-1])
    print('')
    _append_line(os.path.join(save_dir, "val_losses.txt"), val_losses[-1])

    torch.save(model.state_dict(), checkpoint_path)

    end_epoch = time.time()
    _append_line(os.path.join(save_dir, "time.txt"), end_epoch - start_epoch)

  return train_losses, val_losses

## Train

### Data define

In [None]:
# Names used by the training cells: trajectories and 0/1 masks for each split.
train_data, particle_types_train = positions, particle_types
val_data, particle_types_val = val_positions, val_particle_types

### Define CMPNV3

In [None]:
# Release cached GPU memory held by PyTorch (e.g. from a previously built model)
# before creating a new one.
torch.cuda.empty_cache()

In [None]:
# Model input width: 2 (position) + 2*5 (five velocity steps) + 4 (wall offsets)
# + 1 (particle mask), matching the features produced by build_scene.
cmpnv3 = CMPNV3(in_channels=2 + 2*5 + 4 + 1, hidden_channels=128).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=cmpnv3.parameters(), lr=1e-3)
# NOTE(review): gamma=1 makes MultiStepLR a no-op (lr stays 1e-3 at every
# milestone) — confirm whether a decay factor < 1 was intended.
scheduler = lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[3, 5, 7], gamma=1)
epochs = 10

In [None]:
#with open("/content/drive/MyDrive/GNN/tets_1/net_info.txt", "w") as file:
#  print(net_arch, file=file)

In [None]:
#total_params = sum(p.numel() for p in cmpnv3.parameters())
#with open("/content/drive/MyDrive/GNN/tets_1/net_info.txt", "a") as file:
#  print('TOTAL_PARAMS:', total_params, file=file)

In [None]:
#with open("/content/drive/MyDrive/GNN/tets_1/net_info.txt", "a") as file:
#  print('LR_strat:', 'Adam', 'lr=1e-3', file=file)

In [None]:
#torch.save(cmpnv3.state_dict(), '/content/drive/MyDrive/GNN/tets_1/state_dict_model.pt')

### Train per step

In [None]:
train_losses, val_losses = train_vel(
    model=cmpnv3, optimizer=optimizer, scheduler=scheduler, criterion=criterion, epochs=epochs,
    train_data=train_data, particle_types_train=particle_types_train,
    val_data=val_data, particle_types_val=particle_types_val)

In [None]:
# Running this again trains for another `epochs` epochs (train_vel resumes from the
# saved checkpoint) and overwrites train_losses / val_losses with the new run.
train_losses, val_losses = train_vel(
    model=cmpnv3, optimizer=optimizer, scheduler=scheduler, criterion=criterion, epochs=epochs,
    train_data=train_data, particle_types_train=particle_types_train,
    val_data=val_data, particle_types_val=particle_types_val)

### Figure

In [None]:
# Training and validation loss curves, one point per epoch.
fig, ax = plt.subplots(figsize=(10, 7))

for losses, colour, caption in ((train_losses, 'blue', 'train losses'),
                                (val_losses, 'red', 'validation losses')):
  ax.plot(range(1, epochs + 1), losses, color=colour, label=caption)
ax.set_title('Losses per epoch:')
ax.legend()
plt.show()

# Посмотрим на предсказанную эволюцию 

## Предсказание следующего шага

In [None]:
# Trajectory 1 of the training split and its 0/1 mask, used by the plots below.
test_data, particle_types_test = positions[1], particle_types[1]

In [None]:
# Display the shape of frame 0 of trajectory 1 (expected [nodes_num, dim]).
torch.tensor(positions[1][0]).shape

In [None]:
# NOTE(review): identical to the previous cell — candidate for removal.
torch.tensor(positions[1][0]).shape

In [None]:
# kNN graph (k=15, self-loops included) over all particles of frame 0 of trajectory 1.
edge_index = knn_graph(torch.tensor(positions[1][0]), 15, loop=True)

In [None]:
# Display the edge_index shape (expected [2, E]).
edge_index.shape

In [None]:
# Display the last row of edge_index.
edge_index[-1]

In [None]:
# Take the last 15 entries of the first row of edge_index (`ans`) and the first
# of those (`point`); both are highlighted in the plot below.
res = np.array(edge_index)
ans = res[0][-15:]
print(ans)
point = ans[0]
print(point)

In [None]:
# Display the shape of frame 1 of the test trajectory.
test_data[1].shape

In [None]:
# Frame 1 of the test trajectory: every particle in blue, the entries of `ans`
# in red, and `point` in green.
frame = test_data[1]
fig = plt.figure(figsize=(24, 12))
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_xlim(metadata['bounds'][0][0], metadata['bounds'][0][1])
ax1.set_ylim(metadata['bounds'][1][0], metadata['bounds'][1][1])

X, Y = frame[:, 0], frame[:, 1]
X1, Y1 = frame[ans, 0], frame[ans, 1]
X2, Y2 = frame[point, 0], frame[point, 1]

ax1.scatter(X, Y, color='blue')
ax1.scatter(X1, Y1, color='red')
ax1.scatter(X2, Y2, color='green')

In [None]:
# Ground-truth frames 1..99 of the test trajectory: particles with mask 1 in
# blue, all others in red. Each frame replaces the previous one in the output.
for i in range(1, 100):

  fig = plt.figure(figsize=(24, 12))
  ax1 = fig.add_subplot(1, 2, 1)
  ax1.set_xlim(metadata['bounds'][0][0], metadata['bounds'][0][1])
  ax1.set_ylim(metadata['bounds'][1][0], metadata['bounds'][1][1])

  X = [test_data[i][j][0] for j in range(test_data[i].shape[0]) if particle_types_test[j]==1]
  Y = [test_data[i][j][1] for j in range(test_data[i].shape[0]) if particle_types_test[j]==1]

  X_borders = [test_data[i][j][0] for j in range(test_data[i].shape[0]) if particle_types_test[j]==0]
  Y_borders = [test_data[i][j][1] for j in range(test_data[i].shape[0]) if particle_types_test[j]==0]

  ax1.scatter(X, Y, color='blue')
  ax1.scatter(X_borders, Y_borders, color='red')

  # Show frames one at a time, as the prediction cells do, instead of
  # accumulating 99 open figures until the cell finishes.
  clear_output(wait=True)
  plt.show()

In [None]:
# One-step prediction check. The left panel shows ground-truth frame i. The right
# panel shows the model's prediction of frame i, made from the true frame i-1 and
# its 5-step velocity history. These are the same inputs build_scene produced
# during training, but without noise. Frame i-1 needs 5 predecessors, so i
# starts at 6.
cmpnv3.eval()  # use BatchNorm running statistics, not per-graph batch statistics
for i in range(6, 100):

  fig = plt.figure(figsize=(24, 12))
  ax1 = fig.add_subplot(1, 2, 1)
  ax1.set_xlim(metadata['bounds'][0][0], metadata['bounds'][0][1])
  ax1.set_ylim(metadata['bounds'][1][0], metadata['bounds'][1][1])

  X = [test_data[i][j][0] for j in range(test_data[i].shape[0]) if particle_types_test[j]==1]
  Y = [test_data[i][j][1] for j in range(test_data[i].shape[0]) if particle_types_test[j]==1]

  X_borders = [test_data[i][j][0] for j in range(test_data[i].shape[0]) if particle_types_test[j]==0]
  Y_borders = [test_data[i][j][1] for j in range(test_data[i].shape[0]) if particle_types_test[j]==0]

  ax1.scatter(X, Y, color='blue')
  ax1.scatter(X_borders, Y_borders, color='red')

  with torch.no_grad():

    features, nodes_mask, _ = build_scene([test_data], [particle_types_test], (0, i - 1), metadata, 5, 0.0)
    pred_img = cmpnv3(features.to(device), nodes_mask.to(device)).cpu()

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_xlim(metadata['bounds'][0][0], metadata['bounds'][0][1])
    ax2.set_ylim(metadata['bounds'][1][0], metadata['bounds'][1][1])

    X = [pred_img[j][0].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==1]
    Y = [pred_img[j][1].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==1]

    X_borders = [pred_img[j][0].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==0]
    Y_borders = [pred_img[j][1].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==0]

    ax2.scatter(X, Y, color='blue')
    ax2.scatter(X_borders, Y_borders, color='red')

    clear_output(wait=True)
    plt.show()

## Предсказания на основе предыдущих

In [None]:
# Trajectory 53 of the training split and its 0/1 mask, used for the rollout below.
test_data, particle_types_test = positions[53], particle_types[53]

In [None]:
# Autoregressive rollout. It starts from the first HISTORY + 1 ground-truth
# frames and feeds every prediction back in as the newest frame.
# Left panel: the ground-truth frame being predicted. Right panel: the prediction.
HISTORY = 5  # velocity steps per input; must match the input_len used in training
cmpnv3.eval()  # use BatchNorm running statistics
history = [np.asarray(frame) for frame in test_data[:HISTORY + 1]]

for i in range(100):
  target = HISTORY + 1 + i  # index of the ground-truth frame predicted at this step

  fig = plt.figure(figsize=(24, 12))
  ax1 = fig.add_subplot(1, 2, 1)
  ax1.set_xlim(metadata['bounds'][0][0], metadata['bounds'][0][1])
  ax1.set_ylim(metadata['bounds'][1][0], metadata['bounds'][1][1])

  X = [test_data[target][j][0] for j in range(test_data[target].shape[0]) if particle_types_test[j]==1]
  Y = [test_data[target][j][1] for j in range(test_data[target].shape[0]) if particle_types_test[j]==1]

  X_borders = [test_data[target][j][0] for j in range(test_data[target].shape[0]) if particle_types_test[j]==0]
  Y_borders = [test_data[target][j][1] for j in range(test_data[target].shape[0]) if particle_types_test[j]==0]

  ax1.scatter(X, Y, color='blue')
  ax1.scatter(X_borders, Y_borders, color='red')

  with torch.no_grad():
    # build_scene reads frame j=HISTORY of the window and its HISTORY predecessors.
    # It also returns frame j+1 as a target, so the newest frame is repeated as
    # a placeholder whose value is ignored.
    window = history[-(HISTORY + 1):] + [history[-1]]
    features, nodes_mask, _ = build_scene([window], [particle_types_test], (0, HISTORY), metadata, HISTORY, 0.0)
    pred_img = cmpnv3(features.to(device), nodes_mask.to(device)).cpu()

    history.append(pred_img.numpy())

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_xlim(metadata['bounds'][0][0], metadata['bounds'][0][1])
    ax2.set_ylim(metadata['bounds'][1][0], metadata['bounds'][1][1])

    X = [pred_img[j][0].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==1]
    Y = [pred_img[j][1].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==1]

    X_borders = [pred_img[j][0].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==0]
    Y_borders = [pred_img[j][1].item() for j in range(pred_img.shape[0]) if particle_types_test[j]==0]

    ax2.scatter(X, Y, color='blue')
    ax2.scatter(X_borders, Y_borders, color='red')

    clear_output(wait=True)
    plt.show()