<h1>CNN for Multiple chunks of data</h1>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

import random
import numpy as np
torch.manual_seed(101101)
random.seed(101101)
np.random.seed(101101)

  warn(f"Failed to load image Python extension: {e}")


In [2]:
import os
# os.environ['CUDA_VISIBLE_DEVICES']

In [3]:
print(torch.__version__)

1.11.0


In [4]:
print(torch.cuda.is_available())

True


In [5]:
torch.cuda.device_count()

1

In [6]:
import subprocess
subprocess.run(["python", "--version"])

Python 3.10.9


CompletedProcess(args=['python', '--version'], returncode=0)

In [7]:
import torchvision

In [8]:
print(torch.backends.cudnn.enabled)

True


In [9]:
subprocess.run(['nvcc','--version'])

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0


CompletedProcess(args=['nvcc', '--version'], returncode=0)

In [10]:
subprocess.run(['conda','list'])

# packages in environment at /ext3/miniconda3/envs/cnnenv:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
alsa-lib                  1.2.8                h166bdaf_0    conda-forge
anyio                     3.6.2              pyhd8ed1ab_0    conda-forge
aom                       3.5.0                h27087fc_0    conda-forge
argon2-cffi               21.3.0             pyhd8ed1ab_0    conda-forge
argon2-cffi-bindings      21.2.0          py310h5764c6d_3    conda-forge
astropy                   5.2.1           py310h0a54255_0    conda-forge
asttokens                 2.2.1              pyhd8ed1ab_0    conda-forge
attr                      2.5.1                h166bdaf_1    conda-forge
attrs                     22.2.0             pyh71513ae_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-fo

CompletedProcess(args=['conda', 'list'], returncode=0)

<h2>Model definition</h2>

Higher stride at higher layers,

In [77]:
class DownSizeNet(nn.Module):
    def __init__(self, in_size, out_size,stride=1, padding=1, leaky_slope=0.2):
        super(DownSizeNet, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 3, stride=stride, padding=padding, bias=False).double()]
        layers.append(nn.MaxPool2d(3, stride=stride).double())
        layers.append(nn.LeakyReLU(leaky_slope).double())
        
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        
        return self.model(x)
    
class ChunkNet(nn.Module):
    def __init__(self,s_in,s_h,s_out):
        super(ChunkNet, self).__init__()
        self.dense1 = nn.Sequential(nn.Linear(s_in,s_h), nn.ReLU()).double()
        self.dense2 = nn.Sequential(nn.Linear(s_h,s_out), nn.ReLU()).double()
        
        self.final = nn.Linear(s_out,1).double()
        
    def forward(self,x):
        d1 = self.dense1(x)
        d2 = self.dense2(d1)
        
        return self.final(d2)

    
class FeatureNet(nn.Module):
    def __init__(self):
        super(FeatureNet, self).__init__()
        self.down1 = DownSizeNet(3,    64,stride=1)
        self.down2 = DownSizeNet(64,  128,stride=1,padding=0)
        self.down3 = DownSizeNet(128, 256,stride=1,padding=0)
        self.down4 = DownSizeNet(256, 256,stride=2,padding=0)
#         self.down5 = DownSizeNet(256, 256,stride=2)
#         self.down6 = DownSizeNet(256, 256,stride=2)
        
        
    def forward(self, x):
        # Propogate noise through fc layer and reshape to img shape
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
#         d5 = self.down5(d4)
#         d6 = self.down6(d5)
        
        return d4
    
class RV_Model(nn.Module):
    def __init__(self,s_c,s_in,s_h,s_out,device):
        super(RV_Model, self).__init__()
#         self.chunk_models = []
        
        self.chunk_models = nn.ModuleList([ChunkNet(s_in,s_h,s_out).double().to(device) for i in range(s_c)]).to(device)
#         for i in range(s_c):
#             self.chunk_models.append()
        self.feature_model = FeatureNet().double().to(device)
        
    def forward(self,x,indices):
        y1 = self.feature_model(x)
        y1 = torch.flatten(y1,1)
        y2 = torch.empty(y1.shape[0])
#         print(indices.shape)
        for i,index in enumerate(indices):
            temp = int(index.item())
#             print('index', temp)
            y2[i] = self.chunk_models[temp](y1[i,:])
#             y2[index] = self.chunk_model[index](y1[index,...])
        return y2

In [63]:
from torch.utils.data import Dataset
import glob

In [13]:
from astropy.io import fits

<h2>Dataset importing</h2>
The question here is how is the data organized in the directory and how can it be imported with the target RV. 

Not the OG data but after the data is saved from the pre processing step.

In [14]:
import pickle

In [15]:
import numpy as np
import h5py as h5

256 by 17 with 16 overlap should be the best size cropping. shit shit shit
0-256 240-496 480-736 720-976 
what is wrong with me im dumb
so dumb dumb dumb

In [16]:
import pickle
def save(filename,model):
    with open(filename, 'wb') as output:  # Overwrites any existing file.
        pickle.dump(model, output, pickle.HIGHEST_PROTOCOL)

def load(filename):
    with open(filename, 'rb') as input:  # Overwrites any existing file.
        model = pickle.load(input)
        return model

In [17]:
import os.path as path

In [18]:
files1 = glob.glob('/scratch/mdd423/CNN_EPRV/data/peg51_256/*.h5')

crop median and sift through h5s here. then pickle as nd arrays

In [19]:
all_directories = glob.glob('/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/*02-28*')
files = []
for indiv in all_directories:
    files += glob.glob(indiv + '/*.h5')

In [20]:
files

['/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_1680-1936.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_960-1216_0-256.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_2880-3136.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_1440-1696.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_2160-2416.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_2640-2896.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_480-736.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_960-1216.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_2400-2656.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_3120-3376.h5',
 '/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/

In [21]:
def h5_to_array(ds,target,location,hdu_num='hdu_1'):
    rvs_stack = []
    bcs_stack = []
    tim_stack = []
    
    img_stack = np.empty((0,3,256,256))
    for visit_name,vist_info in ds['visits'].items():
#         for hdu_num in ds['images'][visit_name].keys():
        img   = np.array(ds['images'][visit_name][hdu_num])
#         
        if np.sum(img.shape) == (256*2):
            rvs_stack += [np.double(vist_info.attrs['ESO DRS CCF RVC'])]
            temp_time = at.Time(visit_name.split('HARPS.')[1])
            tim_stack += [temp_time]
            bcs_stack += [target.radial_velocity_correction(obstime=temp_time, location=location).to('km/s').value]
            all_flats = np.empty([0,256,256])
            for key in vist_info.attrs.keys():
                if vist_info.attrs[key] == 'FLAT':
                    temp = np.array(ds['images'][key][hdu_num])[None,...]
                    all_flats = np.append(all_flats, temp,axis=0)
                if vist_info.attrs[key] == 'THAR_THAR':
                    cali = np.array(ds['images'][key][hdu_num])
#                 temp = crop_again(np.stack((img,np.median(all_flats,axis=0),cali)))
            temp = np.stack((img,np.median(all_flats,axis=0),cali))[None,...]
#                 print(temp.shape)
            img_stack = np.append(img_stack,temp,axis=0)
    return img_stack, np.array(rvs_stack), np.array(bcs_stack), np.array(tim_stack)

ok so now loop through all h5 files convert to nda arrays so they can later be all load back in. append only after saving. load then stack into one array with an index array that refers an image to its chunk and hdu

In [22]:
import astropy.coordinates as coords
import astropy.time as at

In [23]:
location = coords.EarthLocation.of_site('La Silla Observatory')
target   = coords.SkyCoord.from_name('51PEG')

iterations = 0
img_stack = np.empty((0,3,256,256))
rvs_stack = np.empty((0))
bcs_stack = np.empty((0))
tim_stack = np.empty((0))
ind_stack = np.empty((0))
for filename in files:
    for hdu in ['hdu_1','hdu_2']:
#     filename             = files1[0]

#         ds                   = h5.File(filename,'r')
#         img_stack, rvs_stack, bcs_stack, tim_stack = h5_to_array(ds,target,location,hdu_num=hdu)
        dir_name, tailname = path.split(filename)
        tailname = tailname[:-3]
        imgname = path.join(dir_name, tailname + hdu + '_img.nda')
        rvsname = path.join(dir_name, tailname + hdu + '_rvs.nda')
        bcsname = path.join(dir_name, tailname + hdu + '_bcs.nda')
        timname = path.join(dir_name, tailname + hdu + '_tim.nda')
        print(imgname,rvsname)
#         save(imgname,img_stack)
#         save(rvsname,rvs_stack)
#         save(bcsname,bcs_stack)
#         save(timname,tim_stack)
        img_stack = np.append(img_stack, load(imgname),axis=0)
        rvs_stack = np.append(rvs_stack, load(rvsname),axis=0)
        bcs_stack = np.append(bcs_stack, load(bcsname),axis=0)
        tim_stack = np.append(tim_stack, load(timname),axis=0)
        
        ind_stack = np.append(ind_stack, iterations*np.ones(bcs_stack.shape,dtype=int))
        iterations += 1

/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_1680-1936hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_1680-1936hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_1680-1936hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_720-976_1680-1936hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_960-1216_0-256hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_960-1216_0-256hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_960-1216_0-256hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_960-1216_0-256hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_2880-3136hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28/51Peg_240-496_2880-3136hdu_1_

/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_720-976_3600-3856hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_720-976_3600-3856hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_960-1216_2400-2656hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_960-1216_2400-2656hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_960-1216_2400-2656hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_960-1216_2400-2656hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_480-736_2640-2896hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_480-736_2640-2896hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-4/51Peg_480-736_2640-2896hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-0

/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_1680-1936hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_1680-1936hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_960-1216hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_960-1216hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_960-1216hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_960-1216hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_1440-1696hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_1440-1696hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/51Peg_240-496_1440-1696hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-6/5

/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_720-976_240-496hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_720-976_240-496hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_720-976_240-496hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_720-976_240-496hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_480-736_1200-1456hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_480-736_1200-1456hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_480-736_1200-1456hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_480-736_1200-1456hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg_480-736_1440-1696hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-3/51Peg

/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_0-256_2640-2896hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_0-256_2640-2896hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_0-256_2640-2896hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_0-256_2640-2896hdu_2_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_720-976_1200-1456hdu_1_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_720-976_1200-1456hdu_1_rvs.nda
/scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_720-976_1200-1456hdu_2_img.nda /scratch/mdd423/CNN_EPRV/data/peg51_256/raw/peg51_256/2023-02-28-2/51Peg_720-976_1200-1456hdu_2_rvs.nda


In [24]:
class ND_Dataset(Dataset):
    def __init__(self, imgs,rvs,indices):
        self.img_stack = imgs
        self.rvs_stack = rvs
        self.indices   = indices
        self.type      = torch.Tensor

    def __getitem__(self, index):
        
        return {'img': self.type(self.img_stack[index,...]).double(),
                'rvs': np.double(self.rvs_stack[index]), 
                'indices': self.indices[index]}
    
    
    def __len__(self):
        
        return len(self.rvs_stack)

In [25]:
# files1 = glob.glob('/scratch/mdd423/CNN_EPRV/data/HARPS/tryagain/51Peg-3/*.h5')
# files2 = glob.glob('/scratch/mdd423/CNN_EPRV/data/HARPS/PEG51/51Peg-2/*.h5')
# files3 = glob.glob('/scratch/mdd423/CNN_EPRV/data/HARPS/PEG51/51Peg-3/*.h5')
# files4 = glob.glob('/scratch/mdd423/CNN_EPRV/data/HARPS/PEG51/51Peg-4/*.h5')
# file_list = files1 + files2 + files3 + files4

In [26]:
dataset = ND_Dataset(img_stack,bcs_stack,ind_stack)

In [28]:
len(dataset)

24880

In [29]:
testdata,validdata = torch.utils.data.random_split(dataset,[23880,1000])

In [31]:
import os.path

<h2>Defining Fitting Process</h2>
including hyperparameters, the loss function, and the optimization algo

In [78]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [79]:
print(device)

cuda


In [80]:
lr = 0.001
b1 = 0.9
b2 = 0.999

In [90]:
s_c  = int(np.max(ind_stack))
s_in = 2304
s_h  = 64
s_out= 64
model   = RV_Model(s_c,s_in,s_h,s_out,device).to(device)
mse_loss = torch.nn.MSELoss().double()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(b1, b2))

In [91]:
batch_size = 4
n_cpu = 4
dataloader = DataLoader(
    testdata,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

validloader = DataLoader(
    validdata,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

<h2>Training Step</h2>
the working step!

In [92]:
import sys
import time
import datetime
import itertools

In [93]:
2*len(dataset)/400

124.4

In [94]:
n_epochs = 2
train_loss = []
valid_loss = []

valiter = itertools.cycle(validloader)
b_avg = 0.0
e_avg = 0.0

start_t = time.time()

directory, tail = path.split(filename)
for j,epoch in enumerate(range(n_epochs)):
    
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
#         print(batch['indices'].dtype)
        y    = model(batch['img'].to(device),batch['indices']).squeeze()
        print(y.device)
#         print(model.device)
        loss = mse_loss(y.double(),batch['rvs'].to(device).double())
        loss.backward()
        optimizer.step()
        b_time = time.time() - start_t
        b_avg  = b_time / (i + 1 + (j * len(dataloader)))
        
        if (i % 5) == 0:
            # Validation checkpoint every 5 batches
            valbatch = next(valiter)
            model.eval()
            with torch.no_grad():
                y     = model(     valbatch['img'].to(device),valbatch['indices']).squeeze()
                print(y.device)
                vloss = mse_loss(y,valbatch['rvs'].to(device).double())
                
                r_time = b_avg * ((len(dataloader) * n_epochs) - (i + 1 + (j * len(dataloader))))
                sys.stdout.write(
                        "\r[Epoch %d/%d] [Batch %d/%d] [Train Loss: %f] [Valid Loss: %f] [BT: %s] [ET: %s] [RT: %s]"
                        % (
                            epoch,
                            n_epochs,
                            i,
                            len(dataloader),
                            loss.item(),
                            vloss.item(),
                            str(datetime.timedelta(seconds=b_avg)),
                            str(datetime.timedelta(seconds=e_avg)),
                            str(datetime.timedelta(seconds=r_time))
                        )  
                )
            
            model.train()
            
    train_loss.append(loss.item())
    valid_loss.append(vloss.item())
    if (j % 10) == 0:
        # Saving checkpoint every 10 epoches
        modelpath = '/scratch/mdd423/CNN_EPRV/models/rv_model_multi_{}_{}_{}_bcs.model'.format(tail[:-3],j,n_epochs)
        torch.save(model.state_dict(), modelpath)
    e_time = time.time() - start_t
    e_avg  = e_time/(j + 1)
train_loss = np.array(train_loss)
valid_loss = np.array(valid_loss)

RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 44.49 GiB total capacity; 43.14 GiB already allocated; 90.12 MiB free; 43.36 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [89]:
del model

In [99]:
thing = 0
for parameter in model.chunk_models.parameters():
#     print(parameter.shape)
    thing += np.product(parameter.shape)
    
thing2 = 0
for parameter in model.feature_model.parameters():
    thing2 += np.product(parameter.shape)
print('dense',thing/289)
print('conv   ' ,thing2)

dense 86636.41868512111
conv    960192


In [None]:
modelpath = '/scratch/mdd423/CNN_EPRV/models/rv_model_multi_{}_{}_{}_bcs.model'.format(tail[:-3],j,n_epochs)
torch.save(model.state_dict(), modelpath)

In [None]:
tlname = path.join(dir_name, tailname + '_multi_tl_bcs.nda')
vlname = path.join(dir_name, tailname + '_multi_vl_bcs.nda')
save(tlname,train_loss)
save(vlname,valid_loss)