<a href="https://colab.research.google.com/github/bpratham2001/Bridge/blob/main/ML4SCI/mass_regeressor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mass Regression

Please train a model to estimate (regress) the mass of the particle based on particle images using the provided dataset.

DataSet Description: 125x125 image matrices with name of variables: ieta and iphi, with 4 channels called X_jet (Track pT, DZ and D0, ECAL). Please use at least ECAL and Track pT channels and ‘am‘ as the target feature. If there are more than 4 channels in the dataset then you should use X_jet (Track pT, DZ and D0, ECAL) only. Please train your model on 80% of the data and evaluate on the remaining 20%. Please make sure not to overfit on the test dataset - it will be checked with an independent sample.

Datasets: https://cernbox.cern.ch/s/zUvpkKhXIp0MJ0g

**1,051,917 rows in total**

In [1]:
# @title !wget .parquet files

#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_0.parquet
#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_1.parquet
#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_2.parquet
#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_3.parquet
#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_4.parquet
#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_5.parquet
#!wget https://cernbox.cern.ch/remote.php/dav/public-files/zUvpkKhXIp0MJ0g/top_gun_opendata_6.parquet

parqs = ['top_gun_opendata_0.parquet', 'top_gun_opendata_1.parquet', 'top_gun_opendata_2.parquet',
         'top_gun_opendata_3.parquet', 'top_gun_opendata_4.parquet', 'top_gun_opendata_5.parquet',
         'top_gun_opendata_6.parquet']
folder = '/content/drive/My Drive/top_gun/'

In [2]:
# @title imports and installs

#import pyspark
#from pyspark.sql import SparkSession
#import pandas as pd
from IPython.display import clear_output
import random
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import datetime
from math import sqrt
#from sklearn.metrics import root_mean_squared_error

from google.colab import drive
#drive.mount('/content/drive')

In [3]:
# @title Model Architecture (ResNet)

class ResBlock(nn.Module):
  # This is the residual block, there are 4 residual blocks in ResNet15
  def __init__(self, in_channels, out_channels, stride=1):
    super(ResBlock, self).__init__()
    self.relu = nn.ReLU(inplace=True)
    self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0, bias=False)
    nn.init.kaiming_normal_(self.conv1.weight)
    self.bn1 = nn.BatchNorm2d(in_channels)

    self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False)
    nn.init.kaiming_normal_(self.conv2.weight)
    self.bn2 = nn.BatchNorm2d(in_channels)

    self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
    nn.init.kaiming_normal_(self.conv3.weight)
    self.bn3 = nn.BatchNorm2d(out_channels)

    self.shortcut = nn.Sequential()

    if stride != 1 or in_channels != out_channels:
      shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
      nn.init.kaiming_normal_(shortcut_conv.weight)
      self.shortcut.append(shortcut_conv)
      self.shortcut.append(nn.BatchNorm2d(out_channels))

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)
    out = self.conv3(out)
    out = self.bn3(out)
    out += self.shortcut(x)
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self, num_classes=2):
    super(ResNet, self).__init__()
    self.conv1 = nn.Conv2d(4, 16, kernel_size=1, stride=1, padding=0, bias=False)
    nn.init.kaiming_normal_(self.conv1.weight)
    self.bn1 = nn.BatchNorm2d(16)
    self.relu = nn.ReLU(inplace=True)

    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

    self.res1 = ResBlock(16, 32)
    self.res2 = ResBlock(32, 64)

    #self.conv2 = nn.Conv2d(128, 256, kernel_size=1, stride=1, padding=0, bias=False)
    #nn.init.kaiming_normal_(self.conv2.weight)
    #self.bn2 = nn.BatchNorm2d(256)

    self.res3 = ResBlock(64, 128)
    self.res4 = ResBlock(128, 64)

    self.fc1 = nn.Linear(64 * 125 * 125, 1)
    nn.init.kaiming_normal_(self.fc1.weight)
    #self.fc2 = nn.Linear(16 * 125 * 125, 1)
    #nn.init.kaiming_normal_(self.fc2.weight)

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.maxpool(out)
    out = self.res1(out)
    out = self.res2(out)
    #out = self.conv2(out)
    #out = self.bn2(out)
    out = self.res3(out)
    out = self.res4(out)
    out = torch.flatten(out, 1)
    out = self.fc1(out)
    #out = self.fc2(out)
    out = self.relu(out)
    return out

In [22]:
# @title dataset metadata and helper functions

def get_parquet_rows(parq, start=0, end=-1, drive=False):
  # since rows == row_groups in the dataset, we can use them interchangeably
  lst = []
  if drive:
    parq_file = pq.ParquetFile(folder+parq)
  else:
    parq_file = pq.ParquetFile(parq)
  if end < 0:
    end = parq_file.scan_contents()
  while start < end:
    lst.append(parq_file.read_row_group(start))
    start += 1
  return lst

def make_X(X_jet):
  X_set = np.zeros(shape=(4, 125, 125))
  for i in range(4): # only selecting the first 4 channels (Track pT, DZ, D0, ECAL)
    for j in range(125):
      for k in range(125):
        X_set[i][j][k] = X_jet[i][j][k]
  return torch.tensor(X_set, dtype=torch.float32)


def get_Xy(parq, start=0, end=-1, drive=False):
  arr = np.array(get_parquet_rows(parq, start, end, drive))
  X = []
  y = []
  for i in range(len(arr)):
    X.append(make_X(np.array(arr[i][0][0])))
    y.append(arr[i][0][1])
  return np.array(X), np.array(y)

""" # row counter
tot = 0
for parq in parqs:
  num = pq.ParquetFile(folder+parq).scan_contents()
  tot += num
  print(parq+" contains "+str(num)+" rows")
print("total " + str(tot))
del tot, num

top_gun_opendata_0.parquet contains 150327 rows
top_gun_opendata_1.parquet contains 150165 rows
top_gun_opendata_2.parquet contains 150451 rows
top_gun_opendata_3.parquet contains 150448 rows
top_gun_opendata_4.parquet contains 150557 rows
top_gun_opendata_5.parquet contains 150056 rows
top_gun_opendata_6.parquet contains 149913 rows
total 1051917
"""

for parq in parqs:
  print(pq.read_metadata(folder+parq))
  #print(pq.ParquetFile(folder+parq).read_row_group(0))
  break

class Deck():
  def __init__(self, filenames, filesizes):
    self.filenames = filenames
    self.filesizes = filesizes
    self.deck = []

  def create_deck(self, max_chunk):
    self.deck = []
    for i in range(len(self.filenames)):
      start = 0
      while self.filesizes[i] > 0:
        if max_chunk >= self.filesizes[i]:
          self.deck.append([self.filenames[i], start, (start+self.filesizes[i])-1])
          self.filesizes[i] = 0
        else:
          self.deck.append([self.filenames[i], start, start+max_chunk])
          self.filesizes[i] -= max_chunk
          start += max_chunk

  def shuffle_deck(self):
    random.shuffle(self.deck)

  def deal_cards(self, split):
    return self.deck[:int(len(self.deck)*split)], self.deck[int(len(self.deck)*split):]

  def play_card(self, card, batch_size, drive=False, reshuffle=True):
    if drive:
      parq_file = pq.ParquetFile(folder+card[0])
    else:
      parq_file = pq.ParquetFile(card[0])

    df = parq_file.read_row_groups(np.arange(card[1], card[2]), columns=['X_jet', 'm']).to_pandas()
    df['X_jet'] = df['X_jet'].map(lambda x: make_X(x))

    torch.cuda.empty_cache()
    X_tensor = torch.tensor(df['X_jet'].values, dtype=torch.float32)
    y_tensor = torch.tensor(df['m'].values, dtype=torch.float32)
    X_tensor = X_tensor.to(device)
    y_tensor = y_tensor.to(device)
    return DataLoader(TensorDataset(X_tensor, y_tensor), batch_size=batch_size, shuffle=reshuffle)

<pyarrow._parquet.FileMetaData object at 0x7cb3581d5da0>
  created_by: parquet-cpp version 1.5.1-SNAPSHOT
  num_columns: 5
  num_rows: 150327
  num_row_groups: 150327
  format_version: 1.0
  serialized_size: 72239212


In [5]:
# @title Tuning and Loading the Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###############################################################################
epochs = 1
split = 0.8
max_chunk = 4000
batch_size = 200
optimiseur = ['Adam', 0.01] # ['Adam, learning rate'] / ['SGD', learning rate, momentum]
criterion = nn.MSELoss().to(device)
path = '' # 'mass_regressor.pth'
###############################################################################

model = ResNet()

if torch.cuda.device_count() > 1:
  print("Using ", torch.cuda.device_count(), "GPUs")
  model = nn.DataParallel(model)

if path != '':
  model.load_state_dict(torch.load(f=path))

model = model.to(device)

if optimiseur[0] == 'Adam':
  optimiseur = torch.optim.Adam(model.parameters(),lr=optimiseur[1])
elif optimiseur[0] == 'SGD':
  optimiseur = torch.optim.SGD(model.parameters(), lr=optimiseur[1], momentum=optimiseur[2])

In [24]:
# @title Training

print("Training started at", datetime.datetime.now())
parq_len = [150327, 150165, 150451, 150448, 150557, 150056, 149913]
deck = Deck(parqs, parq_len)
deck.create_deck(max_chunk)
print(deck.deck)
deck.shuffle_deck()
train_cards, test_cards = deck.deal_cards(split)
print(train_cards)
plotloss = []
flag = False

model.train()
for epoch in range(epochs):

  if (epochs > 5) and (epoch > (epochs*0.5)):
    optimiseur = torch.optim.Adam(model.parameters(),lr=0.001)
  elif (epochs > 5) and (epoch > (epochs*0.75)):
    optimiseur = torch.optim.Adam(model.parameters(),lr=0.0001)

  running_loss = 0.0 # MSE, per epoch

  for card in train_cards:
    acc = [] # RMSE, per card per epoch
    for i, (inputs, labels) in enumerate(deck.play_card(card, batch_size, drive=True, reshuffle=False)):
      outputs = model(inputs)
      if flag:
        print(outputs) #torch.flatten(torch.rot90(outputs)))
        print(labels) #torch.flatten(labels))
        flag = False
      loss = criterion(outputs, labels.unsqueeze(1))
      optimiseur.zero_grad()
      loss.backward()
      optimiseur.step()
      running_loss += loss.item()
      acc.append(sqrt(loss))
    print(f"RMSE: {sum(acc)/len(acc)}")
  plotloss.append(running_loss)
  print(f"E [{epoch + 1}/{epochs}], L: {running_loss:.4f}, T: [{datetime.datetime.now()}]")

print("Training completed at", datetime.datetime.now())

Training started at 2025-03-24 16:32:01.054055
[['top_gun_opendata_0.parquet', 0, 4000], ['top_gun_opendata_0.parquet', 4000, 8000], ['top_gun_opendata_0.parquet', 8000, 12000], ['top_gun_opendata_0.parquet', 12000, 16000], ['top_gun_opendata_0.parquet', 16000, 20000], ['top_gun_opendata_0.parquet', 20000, 24000], ['top_gun_opendata_0.parquet', 24000, 28000], ['top_gun_opendata_0.parquet', 28000, 32000], ['top_gun_opendata_0.parquet', 32000, 36000], ['top_gun_opendata_0.parquet', 36000, 40000], ['top_gun_opendata_0.parquet', 40000, 44000], ['top_gun_opendata_0.parquet', 44000, 48000], ['top_gun_opendata_0.parquet', 48000, 52000], ['top_gun_opendata_0.parquet', 52000, 56000], ['top_gun_opendata_0.parquet', 56000, 60000], ['top_gun_opendata_0.parquet', 60000, 64000], ['top_gun_opendata_0.parquet', 64000, 68000], ['top_gun_opendata_0.parquet', 68000, 72000], ['top_gun_opendata_0.parquet', 72000, 76000], ['top_gun_opendata_0.parquet', 76000, 80000], ['top_gun_opendata_0.parquet', 80000, 84

KeyboardInterrupt: 

In [None]:
import seaborn as sns
sns.lineplot(x=np.array(list(range(epochs))), y=np.array(plotloss))

In [None]:
# @title Testing

print("Testing started at", datetime.datetime.now())

val_loss = 0.0
acc = []
model.eval()
with torch.no_grad():
  for card in test_cards:
    for i, (inputs, labels) in enumerate(deck.play_card(card, batch_size, drive=True, reshuffle=False)):
      outputs = model(inputs)
      loss = criterion(outputs, labels.unsqueeze(1))
      val_loss += loss.item()
      acc.append(sqrt(loss))
    print(sqrt(loss))

print(f"avg RMSE: {sum(acc)/len(acc)}")
print("Testing completed at", datetime.datetime.now())

In [None]:
# @title Saving
torch.save(obj=model.state_dict(), f="resnet_regressor.pth")

# Tensorflow stuffs to ignore

In [None]:
# @title imports and installs
!pip install tensorflow==2.18.0
!pip install tensorflow-tpu==2.18.0 --find-links=https://storage.googleapis.com/libtpu-tf-releases/index.html
#import pyspark
#from pyspark.sql import SparkSession
#import pandas as pd
from IPython.display import clear_output
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import os
import tensorflow as tf
import tensorflow.keras as keras
from keras.callbacks import EarlyStopping
from keras.layers import Dense, Conv2D,  MaxPool2D, Flatten, GlobalAveragePooling2D,  BatchNormalization, Layer, Add
from keras.models import Sequential
from keras.models import Model
#import tensorflow_datasets as tfds
import datetime
from math import sqrt
#from sklearn.metrics import root_mean_squared_error

from google.colab import drive
#drive.mount('/content/drive')

In [3]:
# @title Tuning
#TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
strategy = tf.distribute.TPUStrategy(resolver)

#optimiseur = keras.optimizers.SGD(learning_rate=0.1,momentum=0.9,decay = 1e-04)
optimiseur = keras.optimizers.Adam(learning_rate=0.01)
criterion = keras.losses.MeanSquaredError()
batch_size_parq = 5000
batch_size_train = 256
steps_per_epoch = 50000 // batch_size_train
exec_step = 50 # between 2 and steps_per_epoch
epochs = 10

All devices:  [LogicalDevice(name='/device:TPU:0', device_type='TPU'), LogicalDevice(name='/device:TPU:1', device_type='TPU'), LogicalDevice(name='/device:TPU:2', device_type='TPU'), LogicalDevice(name='/device:TPU:3', device_type='TPU'), LogicalDevice(name='/device:TPU:4', device_type='TPU'), LogicalDevice(name='/device:TPU:5', device_type='TPU'), LogicalDevice(name='/device:TPU:6', device_type='TPU'), LogicalDevice(name='/device:TPU:7', device_type='TPU')]


In [None]:
# @title dataset metadata and helper functions

def get_parquet_rows(parq, start=0, end=-1, drive=False):
  # since rows == row_groups in the dataset, we can use them interchangeably
  lst = []
  if drive:
    parq_file = pq.ParquetFile(folder+parq)
  else:
    parq_file = pq.ParquetFile(parq)
  if end < 0:
    end = parq_file.scan_contents()
  while start < end:
    lst.append(parq_file.read_row_group(start))
    start += 1
  return lst

def make_X(X_jet):
  X_set = np.zeros(shape=(4, 125, 125))
  for i in range(4): # only selecting the first 4 channels (Track pT, DZ, D0, ECAL)
    for j in range(125):
      for k in range(125):
        X_set[i][j][k] = X_jet[i][j][k]
  return tf.convert_to_tensor(X_set, tf.float32)

def make_Xs(X_jets):
  X_sets = np.zeros(shape=(len(X_jets), 4, 125, 125))
  for i in range(len(X_jets)):
    for j in range(4): # only selecting the first 4 channels (Track pT, DZ, D0, ECAL)
      for k in range(125):
        for l in range(125):
          X_sets[i][j][k][l] = X_jets[i][j][k][l]
  return tf.convert_to_tensor(X_sets, tf.float32)


def get_Xy(parq, start=0, end=-1, drive=False):
  arr = np.array(get_parquet_rows(parq, start, end, drive))
  X = []
  y = []
  for i in range(len(arr)):
    X.append(make_X(np.array(arr[i][0][0])))
    y.append(arr[i][0][1])
  return np.array(X), np.array(y)

""" # row counter
tot = 0
for parq in parqs:
  num = pq.ParquetFile(folder+parq).scan_contents()
  tot += num
  print(parq+" contains "+str(num)+" rows")
print("total " + str(tot))
del tot, num

top_gun_opendata_0.parquet contains 150327 rows
top_gun_opendata_1.parquet contains 150165 rows
top_gun_opendata_2.parquet contains 150451 rows
top_gun_opendata_3.parquet contains 150448 rows
top_gun_opendata_4.parquet contains 150557 rows
top_gun_opendata_5.parquet contains 150056 rows
top_gun_opendata_6.parquet contains 149913 rows
total 1051917
"""
parq_len = [150327, 150165, 150451, 150448, 150557, 150056, 149913]

for parq in parqs:
  print(pq.read_metadata(folder+parq))
  #print(pq.ParquetFile(folder+parq).read_row_group(0))
  break

"""
print(datetime.datetime.now())
parq_file = pq.ParquetFile(folder+parqs[0])
for i in parq_file.iter_batches(batch_size=4000, columns=['X_jet', 'm']):
  df = i.to_pandas()
  df['X_jet'] = df['X_jet'].map(lambda x: make_X(x))
  break
print(df['X_jet'].iloc[0].shape)
print(datetime.datetime.now())
"""

In [None]:
# @title tf_recordwriter

def read_tfrecord(example):
  features={
    "X": tf.io.FixedLenFeature([], tf.string),
    "y": tf.io.FixedLenFeature([], tf.float32),
    }
  example = tf.io.parse_single_example(example, features)
  #X = make_X(example["X_track_pT", "X_DZ", "X_D0", "X_ECAL"])
  X = tf.io.parse_tensor(example['X'], tf.float32)
  y = tf.cast(example['y'], tf.float32)
  return X, y

def get_datasets(tfRecords, batch_size, shuffle=5000, multiple=False):
  #For optimal performance, read multiple TFRecord files and set option experimental_deterministic = False
  dataset = tf.data.TFRecordDataset(tfRecords, num_parallel_reads=tf.data.experimental.AUTOTUNE)
  option = tf.data.Options()
  if multiple:
    option.experimental_deterministic = False
  dataset = dataset.with_options(option)
  dataset = dataset.map(read_tfrecord)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.shuffle(shuffle)
  dataset = dataset.batch(batch_size)
  return dataset

with tf.io.TFRecordWriter("top_gun.tfrecords") as file_writer:
  print("Writing started at", datetime.datetime.now())
  for i in range(len(parqs)):
    for j in range(parq_len[i]):
      X, y = get_Xy(parqs[i], start=j, end=(j+4000), drive=True)
      X_serialized = tf.io.serialize_tensor(X)
      j += 5000
      features = {
          "X": tf.train.Feature(bytes_list=tf.train.BytesList(value=[X_serialized.numpy()])),
          "y": tf.train.Feature(float_list=tf.train.FloatList(value=y)),
          }
      features = tf.train.Features(feature=features)
      example = tf.train.Example(features=features)
      record_bytes = example.SerializeToString()
      file_writer.write(record_bytes)
      print(datetime.datetime.now())
    print(f"parq {parq[i]} complete")
  print("Writing completed at", datetime.datetime.now())

In [10]:
# @title ResNet18 Architecture
"""
ResNet-18
Reference:
[1] K. He et al. Deep Residual Learning for Image Recognition. CVPR, 2016
[2] K. He, X. Zhang, S. Ren, and J. Sun. Delving deep into rectifiers:
Surpassing human-level performance on imagenet classification. In
ICCV, 2015.
"""

class ResBlock(Model):
  def __init__(self, channels, down_sample=False):
    super().__init__()
    self.__channels = channels
    self.__down_sample = down_sample
    self.__strides = [2, 1] if down_sample else [1, 1]

    self.conv1 = Conv2D(self.__channels, strides=self.__strides[0], kernel_size=(3, 3), padding="same", kernel_initializer="he_normal")
    self.bn1 = BatchNormalization()
    self.conv2 = Conv2D(self.__channels * 2, strides=self.__strides[1], kernel_size=(3, 3), padding="same", kernel_initializer="he_normal")
    self.bn2 = BatchNormalization()
    self.merge = Add()
    if self.__down_sample:
      self.res_conv = Conv2D(self.__channels, strides=2, kernel_size=(1, 1), kernel_initializer="he_normal", padding="same")
      self.res_bn = BatchNormalization()

  def call(self, inputs):
    res = inputs
    x = self.conv1(inputs)
    x = self.bn1(x)
    x = tf.nn.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    if self.__down_sample:
      res = self.res_conv(res)
      res = self.res_bn(res)
    x = self.merge([x, res])
    out = tf.nn.relu(x)
    return out


class ResNet(Model):
  def __init__(self, **kwargs):
    """num_classes: number of classes in specific classification task."""
    super().__init__(**kwargs)
    self.conv1 = Conv2D(64, (7, 7), strides=2, padding="same", kernel_initializer="he_normal")
    self.bn = BatchNormalization()

    self.maxpool = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")

    self.res1_1 = ResBlock(64)
    self.res1_2 = ResBlock(64)

    self.res2_1 = ResBlock(128, down_sample=True)
    self.res2_2 = ResBlock(128)

    self.res3_1 = ResBlock(256, down_sample=True)
    self.res3_2 = ResBlock(256)

    self.res4_1 = ResBlock(512, down_sample=True)
    self.res4_2 = ResBlock(512)

    self.avgpool = GlobalAveragePooling2D()

    self.fc1 = Dense(1, activation="relu")

  def call(self, inputs):
    out = self.conv1(inputs)
    out = self.bn(out)
    out = tf.nn.relu(out)
    out = self.maxpool(out)
    for res_block in [self.res1_1, self.res1_2, self.res2_1, self.res2_2, self.res3_1, self.res3_2, self.res4_1, self.res4_2]:
      out = res_block(out)
    out = self.avgpool(out)
    out = Flatten(out)
    out = self.fc1(out)
    return out

with strategy.scope():
  model = ResNet()
  model.build(input_shape = (None, 125, 125, 4))
  model.compile(optimizer=optimiseur, steps_per_execution=exec_step, loss=criterion, metrics=["mse"])
"""
model = ResNet()
tpu_model = tf.compat.v1.estimator.tpu.keras_to_tpu_model(
    model,
    strategy=strategy)
with strategy.scope():
  tpu_model.build(input_shape = (None, 125, 125, 4))
  tpu_model.compile(optimizer=optimiseur, steps_per_execution=exec_step, loss=criterion, metrics=["mse"])
"""
model.summary()

In [None]:
# @title tensor loop
# print(f"E [{epoch + 1}/{epochs}], L: {running_loss:.4f}, T: [{datetime.datetime.now()}]")
#steps_per_epoch=steps_per_epoch
batch_size_parq=50
print("Training started at", datetime.datetime.now())
plotloss = []
for epoch in range(epochs):
  for parq in parqs:
    parq_file = pq.ParquetFile(folder+parqs[0])
    for i in parq_file.iter_batches(batch_size=batch_size_parq, columns=['X_jet', 'm']):
      df = i.to_pandas()
      X = df['X_jet'].values
      y = df['m'].values
      X = make_Xs(X)
      print(X.shape)
      history = model.fit(x=X, y=y, batch_size=batch_size_train, epochs=1, steps_per_epoch=steps_per_epoch, shuffle=True)
      plotloss.append(sqrt(history.history['mse']))
      print(datetime.datetime.now())
print("Training completed at", datetime.datetime.now())
print(f"Test RMSE: [{sqrt(model.evaluate(x=X_test, y=Y_test, batch_size=batch_size))}], T: [{datetime.datetime.now()}]")