In [1]:
import os
import time

import numpy as np
import torch
from torch import nn
from accelerate import Accelerator
from spender import SpectrumAutoencoder
from spender.data import desi_qso as desi 
from spender.util import mem_report

In [2]:
def prepare_train(seq,niter=800):
    """Fill in default settings for every stage of a training sequence.

    Each stage dict is modified in place: a missing "iteration" entry is set
    to `niter`, and a missing "encoder" entry is set to the stage's "data"
    entry (the very same object, not a copy).

    Parameters
    ----------
    seq : iterable of dict
        Training stages. A stage without "encoder" must provide "data".
    niter : int
        Iteration count given to stages that do not specify one.

    Returns
    -------
    The `seq` object that was passed in.
    """
    for stage in seq:
        stage.setdefault("iteration", niter)
        if "encoder" not in stage:
            stage["encoder"] = stage["data"]
    return seq

def build_ladder(train_sequence):
    """Map every iteration of the whole run to the index of its training stage.

    Parameters
    ----------
    train_sequence : sequence of dict
        Training stages; each must carry an integer 'iteration' entry.

    Returns
    -------
    numpy.ndarray of int
        One entry per iteration (length = sum of all 'iteration' values);
        entry k holds the position in `train_sequence` of the stage that is
        active at iteration k.
    """
    lengths = [stage['iteration'] for stage in train_sequence]
    ladder = np.zeros(sum(lengths), dtype='int')

    start = 0
    for stage_index, length in enumerate(lengths):
        stop = start + length
        ladder[start:stop] = stage_index
        start = stop
    return ladder

def get_all_parameters(models,instruments):
    """Collect the optimizer parameter groups for models and instruments.

    The first group holds the parameters of every model's encoder followed by
    the parameters of ONE decoder, taken from the last entry of `models`.
    If any instrument contributes parameters, a second group with
    learning rate 1e-4 is appended for them.

    Parameters
    ----------
    models : sequence
        Objects exposing `.encoder.parameters()` and `.decoder.parameters()`.
    instruments : sequence
        Objects exposing `.parameters()`; `None` entries are skipped.

    Returns
    -------
    dicts : list of dict
        Parameter groups in the format accepted by torch optimizers.
    n_parameters : int
        Number of elements over all collected parameters with
        `requires_grad` set.

    Raises
    ------
    ValueError
        If `models` is empty (there is no decoder to collect).
    """
    models = list(models)
    if not models:
        raise ValueError("get_all_parameters needs at least one model")

    model_params = []
    # multiple encoders
    for model in models:
        model_params += model.encoder.parameters()

    # number of trainable encoder parameters (printed before the decoder is added)
    print(sum([p.numel() for p in model_params if p.requires_grad]))
    # 1 decoder: explicitly the one of the last model
    model_params += models[-1].decoder.parameters()
    dicts = [{'params':model_params}]

    n_parameters = sum([p.numel() for p in model_params if p.requires_grad])

    instr_params = []
    # instruments
    for inst in instruments:
        if inst is None:
            continue
        instr_params += inst.parameters()
    if instr_params:
        # instrument parameters get their own group with a fixed learning rate
        dicts.append({'params':instr_params,'lr': 1e-4})
        n_parameters += sum([p.numel() for p in instr_params if p.requires_grad])
        print("parameter dict:",dicts[1])
    return dicts,n_parameters

def restframe_weight(model,mu=5000,sigma=2000,amp=30):
    """Bell-shaped weight evaluated on the decoder's restframe wavelength grid.

    Computes ``amp * exp(-(0.5 * (x - mu) / sigma)**2)`` with
    ``x = model.decoder.wave_rest``.

    NOTE(review): the factor 0.5 sits inside the square, so the exponent is
    -0.25*((x-mu)/sigma)**2 rather than the usual Gaussian
    -0.5*((x-mu)/sigma)**2 -- confirm this is intended.

    Parameters
    ----------
    model : object exposing `decoder.wave_rest` (a torch tensor)
    mu : centre of the bump, same units as `wave_rest`
    sigma : width scale of the bump
    amp : peak value, reached where `wave_rest == mu`

    Returns
    -------
    torch.Tensor with the shape of `model.decoder.wave_rest`.
    """
    wave = model.decoder.wave_rest
    scaled_offset = 0.5 * (wave - mu) / sigma
    return amp * torch.exp(-scaled_offset ** 2)

def Loss(model, instrument, batch):
    """Evaluate the model loss for one batch.

    Parameters
    ----------
    model : object exposing `encode(spec)` and
        `loss(spec, w, instrument, z=..., s=...)`
    instrument : passed through unchanged to `model.loss`
    batch : iterable of exactly three items, unpacked as (spec, w, z)

    Returns
    -------
    Whatever `model.loss` returns.
    """
    spectra, weights, redshifts = batch
    # encode first: the latents are handed to the loss explicitly
    latents = model.encode(spectra)
    return model.loss(spectra, weights, instrument, z=redshifts, s=latents)

def checkpoint(accelerator, args, optimizer, scheduler, n_encoder, outfile, losses):
    """Save the state dicts of all models together with the loss history.

    The file written through `accelerator.save` is a dict with two keys:
    "model" (list with one state dict per entry of `args`, each obtained
    after `accelerator.unwrap_model`) and "losses" (`losses` as given).

    Parameters
    ----------
    accelerator : object providing `unwrap_model` and `save`
    args : iterable of (wrapped) models
    optimizer, scheduler, n_encoder : accepted for the call signature but
        not used and not stored in the checkpoint
    outfile : destination handed to `accelerator.save`
    losses : loss record stored as-is

    Returns
    -------
    None
    """
    state_dicts = []
    for wrapped in args:
        state_dicts.append(accelerator.unwrap_model(wrapped).state_dict())

    payload = {"model": state_dicts, "losses": losses}
    accelerator.save(payload, outfile)

def load_model(filename, models, instruments):
    """Load a checkpoint file into the given models.

    The file is expected to hold a dict with the keys "model" (list with one
    state dict per model) and "losses", i.e. the layout written by
    `checkpoint`.

    Parameters
    ----------
    filename : path of the checkpoint file
    models : sequence of modules; entry i receives state dict i
    instruments : sequence of instruments; the device of
        `instruments[0].wave_obs` is used as `map_location`, and entry i
        supplies `wave_obs` / `skyline_mask` for the fallback described below

    Returns
    -------
    models : the same sequence, with weights loaded
    losses : the "losses" entry of the checkpoint

    Notes
    -----
    Two backwards-compatibility paths are applied per model:
    keys containing 'mlp.mlp' are renamed to 'mlp', and if loading raises a
    RuntimeError the instrument buffers are added to the state dict before
    loading is retried.
    """
    from collections import OrderedDict

    device = instruments[0].wave_obs.device
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
    model_struct = torch.load(filename, map_location=device)
    state_dicts = model_struct['model']
    for i, model in enumerate(models):
        state = state_dicts[i]
        # backwards compat: encoder.mlp instead of encoder.mlp.mlp
        if 'encoder.mlp.mlp.0.weight' in state:
            state = OrderedDict((k.replace('mlp.mlp', 'mlp'), v) for k, v in state.items())
            state_dicts[i] = state
        # backwards compat: add instrument to encoder
        try:
            model.load_state_dict(state, strict=False)
        except RuntimeError:
            state['encoder.instrument.wave_obs'] = instruments[i].wave_obs
            state['encoder.instrument.skyline_mask'] = instruments[i].skyline_mask
            # retry with the amended state dict of THIS model
            # (the original indexed model_struct[i]['model'], which raises KeyError)
            model.load_state_dict(state, strict=False)

    losses = model_struct['losses']
    return models, losses

In [3]:
# ---- run configuration ----
z_max = 2.1        # redshift used to extend the restframe grid to the blue
_dir = '/tigress/chhahn/spender_qso/train'
outfile = '/tigress/chhahn/spender_qso/train/models/testing.pt'
latents = 10       # passed to the autoencoder as n_latent
lr = 1e-3

# ---- instruments: one encoder per instrument ----
instruments = [desi.DESI()]
n_encoder = len(instruments)

# ---- data loaders: one training and one validation loader per instrument ----
batch_size = 256
_loader_kwargs = dict(tag="qso_lowz", batch_size=batch_size, shuffle=True, shuffle_instance=True)
trainloaders = [
    inst.get_data_loader(_dir, which="train", **_loader_kwargs)
    for inst in instruments
]
validloaders = [
    inst.get_data_loader(_dir, which="valid", **_loader_kwargs)
    for inst in instruments
]

# ---- restframe wavelength grid for the reconstructed spectra ----
# Note: represents joint dataset wavelength range.
# Lower edge: bluest observed wavelength shifted to the restframe at z_max;
# upper edge: reddest observed wavelength (the print below shows the values).
_wave_obs = instruments[0].wave_obs
lmbda_min = _wave_obs[0]/(1.0+z_max)
lmbda_max = _wave_obs[-1]
bins = 9780
wave_rest = torch.linspace(lmbda_min, lmbda_max, bins, dtype=torch.float32)

print("Restframe:\t{:.0f} .. {:.0f} A ({} bins)".format(lmbda_min, lmbda_max, bins))
print(_dir)

# ---- training sequence: a single stage; defaults are filled in by prepare_train ----
FULL = {
    "data": [True],
    "decoder": True,
}
train_sequence = prepare_train([FULL])

# ---- models: one autoencoder per instrument, all on the same restframe grid ----
n_hidden = (64, 128, 1024)
models = [
    SpectrumAutoencoder(
        instrument,
        wave_rest,
        n_latent=latents,
        n_hidden=n_hidden,
        # one activation entry per hidden layer plus one more
        act=[nn.LeakyReLU()]*(len(n_hidden)+1),
    )
    for instrument in instruments
]

Restframe:	1161 .. 9824 A (9780 bins)
/tigress/chhahn/spender_qso/train


In [4]:
# total number of epochs = summed iteration counts of all training stages
n_epoch = sum(stage['iteration'] for stage in train_sequence)
# wall-clock reference taken before training starts
init_t = time.time()
print("torch.cuda.device_count():", torch.cuda.device_count())
print(f"--- Model {outfile} ---")

torch.cuda.device_count(): 1
--- Model /tigress/chhahn/spender_qso/train/models/testing.pt ---


In [5]:
# Set up the optimizer, the learning-rate schedule and the Accelerator.
# The optimizer and scheduler are created first; afterwards models,
# instruments, data loaders and the optimizer are each passed through
# accelerator.prepare().

# number of encoders, now taken from the model list
n_encoder = len(models)
# parameter groups for the optimizer and the count of trainable parameters
model_parameters, n_parameters = get_all_parameters(models,instruments)

print("model parameters:", n_parameters)
mem_report()

# ladder[k] is the index of the training stage that is active in epoch k
ladder = build_ladder(train_sequence)
optimizer = torch.optim.Adam(model_parameters, lr=lr, eps=1e-4)
# one scheduler step per epoch (total_steps=n_epoch), peak rate `lr`
# NOTE(review): a scalar max_lr presumably applies to every parameter group,
# which would override the 'lr' set for the instrument group in
# get_all_parameters -- confirm against the OneCycleLR documentation
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr,
                                          total_steps=n_epoch)

# fp16 mixed precision; every object is prepared individually
accelerator = Accelerator(mixed_precision='fp16')
models = [accelerator.prepare(model) for model in models]
instruments = [accelerator.prepare(instrument) for instrument in instruments]
trainloaders = [accelerator.prepare(loader) for loader in trainloaders]
validloaders = [accelerator.prepare(loader) for loader in validloaders]
# NOTE(review): the scheduler was built on the un-prepared optimizer and is
# itself not passed through accelerator.prepare -- confirm this is intended
optimizer = accelerator.prepare(optimizer)

3159178
model parameters: 13324798
CPU RAM Free: 796.7 GB
GPU 0 ... Mem Free: 40506MB / 40960MB | Utilization   0%


In [6]:
# track training and validation loss
detailed_loss = np.zeros((2, n_encoder, n_epoch))

for epoch_ in range(n_epoch):
    mem_report()
    mode = train_sequence[ladder[epoch_]]

    # turn on/off model decoder
    for p in models[0].decoder.parameters():
        p.requires_grad = True #mode['decoder']

    # turn on/off encoder
    for p in models[0].encoder.parameters():
        p.requires_grad = True

    models[0].train()
    instruments[0].train()

    n_sample = 0
    for k, batch in enumerate(trainloaders[0]):
        loss = Loss(models[0], instruments[0], batch)
        
        accelerator.backward(loss)
        # clip gradients: stabilizes training with similarity
        accelerator.clip_grad_norm_(model_parameters[0]['params'], 1.0)
        # once per batch
        optimizer.step()
        optimizer.zero_grad()

        # logging: training
        detailed_loss[0,0,epoch_] += loss #tuple( l.item() if hasattr(l, 'item') else 0 for l in losses )
        n_sample += batch_size

    detailed_loss[0,0,epoch_] /= n_sample

    scheduler.step()

    with torch.no_grad():
        models[0].eval()
        instruments[0].eval()

        n_sample = 0
        for k, batch in enumerate(validloaders[0]):
            loss = Loss(models[0], instruments[0], batch)
            # logging: validation
            detailed_loss[1,0,epoch_] += loss #tuple( l.item() if hasattr(l, 'item') else 0 for l in losses )
            n_sample += batch_size

        detailed_loss[1,0,epoch_] /= n_sample

    losses = tuple(detailed_loss[0, :, epoch_])
    vlosses = tuple(detailed_loss[1, :, epoch_])
    print('====> Epoch: %i' % (epoch_))
    print('TRAINING Losses:', losses)
    print('VALIDATION Losses:', vlosses)

    #if epoch_ % 5 == 0 or epoch_ == n_epoch - 1:
    #    args = models
    #    checkpoint(accelerator, args, optimizer, scheduler, n_encoder, outfile, detailed_loss)

CPU RAM Free: 795.7 GB
GPU 0 ... Mem Free: 39886MB / 40960MB | Utilization   2%
====> Epoch: 0
TRAINING Losses: (11.940950007530713,)
VALIDATION Losses: (12.78382681512688,)
CPU RAM Free: 794.0 GB
GPU 0 ... Mem Free: 16204MB / 40960MB | Utilization  59%
====> Epoch: 1
TRAINING Losses: (11.71884135388156,)
VALIDATION Losses: (12.611339784035106,)
CPU RAM Free: 793.9 GB
GPU 0 ... Mem Free: 16202MB / 40960MB | Utilization  59%
====> Epoch: 2
TRAINING Losses: (9.090837845454892,)
VALIDATION Losses: (8.020053036817343,)
CPU RAM Free: 793.9 GB
GPU 0 ... Mem Free: 16202MB / 40960MB | Utilization  59%
====> Epoch: 3
TRAINING Losses: (6.076129480193752,)
VALIDATION Losses: (6.048764365369714,)
CPU RAM Free: 794.0 GB
GPU 0 ... Mem Free: 16202MB / 40960MB | Utilization  59%
====> Epoch: 4
TRAINING Losses: (4.21785553758295,)
VALIDATION Losses: (3.987391852109502,)
CPU RAM Free: 793.9 GB
GPU 0 ... Mem Free: 16202MB / 40960MB | Utilization  59%
====> Epoch: 5
TRAINING Losses: (4.021403152058578,)
V

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/chhahn/projects/spender_qso/spender/util.py", line 201, in interp1d
        torch.jit.fork(interp1d_single, x[i], y[i], target[i], mask) for i in range(bs)
    ]
    itp = torch.stack([torch.jit.wait(f) for f in futures])
                       ~~~~~~~~~~~~~~ <--- HERE

    return itp
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/chhahn/projects/spender_qso/spender/util.py", line 199, in <forked function>
    # this is apparantly how parallelism works in pytorch?
    futures = [
        torch.jit.fork(interp1d_single, x[i], y[i], target[i], mask) for i in range(bs)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    ]
    itp = torch.stack([torch.jit.wait(f) for f in futures])
  File "/home/chhahn/projects/spender_qso/spender/util.py", line 129, in interp1d_single
    b = y[:-1] - (m * x[:-1])

    idx = torch.sum(torch.ge(target[:, None], x[None, :]), 1) - 1
          ~~~~~~~~~ <--- HERE
    idx = torch.clamp(idx, 0, len(m) - 1)
RuntimeError: CUDA out of memory. Tried to allocate 582.00 MiB (GPU 0; 39.56 GiB total capacity; 36.85 GiB already allocated; 518.00 MiB free; 38.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

