In [1]:
from pathlib import Path

import torch
import numpy as np
import random
import pickle
from absl import logging
from absl.flags import FLAGS
from cellot import losses
from cellot.utils.loaders import load
from cellot.models.cellot import compute_loss_f, compute_loss_g, compute_w2_distance
from cellot.train.summary import Logger
from cellot.data.utils import cast_loader_to_iterator
from cellot.models.ae import compute_scgen_shift
from tqdm import trange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TARGET = 'all' # 'all' denotes all drugs

In [3]:
import omegaconf

yaml_str = f"""
model:
   name: scgen
   beta: 0.0
   dropout: 0.0
   hidden_units: [512, 512]
   latent_dim: 50

optim:
   lr: 0.001
   optimizer: Adam
   weight_decay: 1.0e-05

scheduler:
   gamma: 0.5
   step_size: 100000

training:
  cache_freq: 10000
  eval_freq: 2500
  logs_freq: 250
  n_iters: 250000

data:
  type: cell
  source: control
  condition: drug
  path: /Mounts/rbg-storage1/users/johnyang/cellot/datasets/scrna-sciplex3/hvg.h5ad
  target: {TARGET}

datasplit:
    groupby: drug   
    name: train_test
    test_size: 0.2
    random_state: 0

dataloader:
    batch_size: 256
    shuffle: true
"""

config = omegaconf.OmegaConf.create(yaml_str)

### Utils

In [4]:
def load_lr_scheduler(optim, config):
    if "scheduler" not in config:
        return None

    return torch.optim.lr_scheduler.StepLR(optim, **config.scheduler)

def check_loss(*args):
    for arg in args:
        if torch.isnan(arg):
            raise ValueError


def load_item_from_save(path, key, default):
    path = Path(path)
    if not path.exists():
        return default

    ckpt = torch.load(path)
    if key not in ckpt:
        logging.warn(f"'{key}' not found in ckpt: {str(path)}")
        return default

    return ckpt[key]

In [5]:
import cellot.models
from cellot.data.cell import load_cell_data


def load_data(config, **kwargs):
    data_type = config.get("data.type", "cell")
    if data_type in ["cell", "cell-merged", "tupro-cohort"]:
        loadfxn = load_cell_data

    elif data_type == "toy":
        loadfxn = load_toy_data

    else:
        raise ValueError

    return loadfxn(config, **kwargs)


def load_model(config, restore=None, **kwargs):
    name = config.model.name
    if name == "cellot":
        loadfxn = cellot.models.load_cellot_model

    elif name == "scgen":
        loadfxn = cellot.models.load_autoencoder_model

    elif name == "cae":
        loadfxn = cellot.models.load_autoencoder_model

    elif name == "popalign":
        loadfxn = cellot.models.load_popalign_model

    else:
        raise ValueError

    return loadfxn(config, restore=restore, **kwargs)


def load(config, restore=None, include_model_kwargs=False, **kwargs):

    loader, model_kwargs = load_data(config, include_model_kwargs=True, **kwargs)

    model, opt = load_model(config, restore=restore, **model_kwargs)

    if include_model_kwargs:
        return model, opt, loader, model_kwargs

    return model, opt, loader

### Training

In [6]:
def train_auto_encoder(outdir, config):
    def state_dict(model, optim, **kwargs):
        state = {
            "model_state": model.state_dict(),
            "optim_state": optim.state_dict(),
        }

        if hasattr(model, "code_means"):
            state["code_means"] = model.code_means

        state.update(kwargs)

        return state

    def evaluate(vinputs):
        with torch.no_grad():
            loss, comps, _ = model(vinputs)
            loss = loss.mean()
            comps = {k: v.mean().item() for k, v in comps._asdict().items()}
            check_loss(loss)
            logger.log("eval", loss=loss.item(), step=step, **comps)
        return loss

    logger = Logger(outdir / "cache/scalars")
    cachedir = outdir / "cache"
    model, optim, loader = load(config, restore=cachedir / "last.pt")

    iterator = cast_loader_to_iterator(loader, cycle_all=True)
    scheduler = load_lr_scheduler(optim, config)

    n_iters = config.training.n_iters
    step = load_item_from_save(cachedir / "last.pt", "step", 0)
    if scheduler is not None and step > 0:
        scheduler.last_epoch = step

    best_eval_loss = load_item_from_save(
        cachedir / "model.pt", "best_eval_loss", np.inf
    )

    eval_loss = best_eval_loss

    ticker = trange(step, n_iters, initial=step, total=n_iters)
    for step in ticker:

        model.train()
        inputs = next(iterator.train)
        optim.zero_grad()
        loss, comps, _ = model(inputs)
        loss = loss.mean()
        comps = {k: v.mean().item() for k, v in comps._asdict().items()}
        loss.backward()
        optim.step()
        check_loss(loss)

        if step % config.training.logs_freq == 0:
            # log to logger object
            logger.log("train", loss=loss.item(), step=step, **comps)

        if step % config.training.eval_freq == 0:
            model.eval()
            eval_loss = evaluate(next(iterator.test))
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                sd = state_dict(model, optim, step=(step + 1), eval_loss=eval_loss)

                torch.save(sd, cachedir / "model.pt")

        if step % config.training.cache_freq == 0:
            torch.save(state_dict(model, optim, step=(step + 1)), cachedir / "last.pt")

            logger.flush()

        if scheduler is not None:
            scheduler.step()

    if config.model.name == "scgen" and config.get("compute_scgen_shift", True):
        labels = loader.train.dataset.adata.obs[config.data.condition]
        compute_scgen_shift(model, loader.train.dataset, labels=labels)

    torch.save(state_dict(model, optim, step=step), cachedir / "last.pt")

    logger.flush()

### Outdir

In [7]:
import torch
import GPUtil
import os

def get_free_gpu():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    # Set environment variables for which GPUs to use.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    chosen_gpu = ''.join(
        [str(x) for x in GPUtil.getAvailable(order='memory')])
    os.environ["CUDA_VISIBLE_DEVICES"] = chosen_gpu
    print(f"Using GPUs: {chosen_gpu}")
    return chosen_gpu

In [8]:
device = f'cuda:{get_free_gpu()}'

cuda
Using GPUs: 1


In [9]:
from pathlib import Path
outdir_path = '/Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/out_of_the_box_ae'
outdir = Path(outdir_path)

In [10]:
from tqdm import tqdm

In [11]:
def evaluate(vinputs, tmodel):
    with torch.no_grad():
        loss, comps, _ = tmodel(vinputs)
        loss = loss.mean()
        comps = {k: v.mean().item() for k, v in comps._asdict().items()}
        check_loss(loss)
        logger.log("eval", loss=loss.item(), step=step, **comps)
    return loss

logger = Logger(outdir / "cache/scalars")
cachedir = outdir / "cache"
model, optim, loader = load(config, restore=cachedir / "last.pt")

iterator = cast_loader_to_iterator(loader, cycle_all=True)
scheduler = load_lr_scheduler(optim, config)

n_iters = config.training.n_iters
step = load_item_from_save(cachedir / "last.pt", "step", 0)
if scheduler is not None and step > 0:
    scheduler.last_epoch = step

best_eval_loss = load_item_from_save(
    cachedir / "model.pt", "best_eval_loss", np.inf
)

eval_loss = best_eval_loss

# ticker = trange(step, n_iters, initial=step, total=n_iters)
eval_losses = []

for batch in tqdm(iterator.test):
    model.eval()
    eval_loss = evaluate(batch, model)
    eval_losses.append(eval_loss)
    if eval_loss < best_eval_loss:
        best_eval_loss = eval_loss
        
    print(eval_loss)
        # sd = state_dict(model, optim, step=(step + 1), eval_loss=eval_loss)
        # torch.save(sd, cachedir / "model.pt")

2023-06-23 11:13:00,698 Loaded cell data with TARGET all and OBS SHAPE (762039, 16)
  logging.warn(f"'{key}' not found in ckpt: {str(path)}")
2023-06-23 11:13:07,748 'best_eval_loss' not found in ckpt: /Mounts/rbg-storage1/users/johnyang/cellot/results/sciplex3/out_of_the_box_ae/cache/model.pt
25it [00:00, 122.16it/s]

tensor(0.0872)
tensor(0.0766)
tensor(0.0877)
tensor(0.0767)
tensor(0.0845)
tensor(0.0814)
tensor(0.0768)
tensor(0.0843)
tensor(0.0896)
tensor(0.0873)
tensor(0.0888)
tensor(0.0910)
tensor(0.0797)
tensor(0.0794)
tensor(0.0844)
tensor(0.0785)
tensor(0.0858)
tensor(0.0844)
tensor(0.0938)
tensor(0.0850)
tensor(0.0855)
tensor(0.0792)
tensor(0.0781)
tensor(0.0957)
tensor(0.0830)
tensor(0.0853)
tensor(0.0762)


54it [00:00, 134.88it/s]

tensor(0.0845)
tensor(0.0858)
tensor(0.0850)
tensor(0.0840)
tensor(0.0833)
tensor(0.0812)
tensor(0.0850)
tensor(0.0840)
tensor(0.0770)
tensor(0.0866)
tensor(0.0892)
tensor(0.0881)
tensor(0.0877)
tensor(0.0835)
tensor(0.0798)
tensor(0.0843)
tensor(0.0808)
tensor(0.0786)
tensor(0.0842)
tensor(0.0827)
tensor(0.0756)
tensor(0.0846)
tensor(0.0740)
tensor(0.0817)
tensor(0.0874)
tensor(0.0762)
tensor(0.0811)
tensor(0.0838)
tensor(0.0804)


86it [00:00, 145.74it/s]

tensor(0.0933)
tensor(0.0779)
tensor(0.0748)
tensor(0.0877)
tensor(0.0965)
tensor(0.0799)
tensor(0.0914)
tensor(0.0850)
tensor(0.0797)
tensor(0.0818)
tensor(0.0791)
tensor(0.0850)
tensor(0.0827)
tensor(0.0801)
tensor(0.0838)
tensor(0.0845)
tensor(0.0896)
tensor(0.0858)
tensor(0.0910)
tensor(0.0797)
tensor(0.0716)
tensor(0.0809)
tensor(0.0826)
tensor(0.0763)
tensor(0.0827)
tensor(0.0737)
tensor(0.0779)
tensor(0.0874)
tensor(0.0889)
tensor(0.0851)
tensor(0.0785)


117it [00:00, 146.89it/s]

tensor(0.0765)
tensor(0.0768)
tensor(0.0839)
tensor(0.0798)
tensor(0.0940)
tensor(0.0859)
tensor(0.0808)
tensor(0.0873)
tensor(0.0809)
tensor(0.0809)
tensor(0.0857)
tensor(0.0923)
tensor(0.0895)
tensor(0.0817)
tensor(0.0797)
tensor(0.0823)
tensor(0.0815)
tensor(0.0753)
tensor(0.0914)
tensor(0.0796)
tensor(0.0858)
tensor(0.0904)
tensor(0.0704)
tensor(0.0821)
tensor(0.0784)
tensor(0.0816)
tensor(0.0777)
tensor(0.0816)
tensor(0.0860)
tensor(0.0875)
tensor(0.0761)


148it [00:01, 146.42it/s]

tensor(0.0808)
tensor(0.0902)
tensor(0.0783)
tensor(0.0802)
tensor(0.0801)
tensor(0.0741)
tensor(0.0750)
tensor(0.0791)
tensor(0.0870)
tensor(0.0851)
tensor(0.0768)
tensor(0.0829)
tensor(0.0867)
tensor(0.0808)
tensor(0.0839)
tensor(0.0853)
tensor(0.0860)
tensor(0.0975)
tensor(0.0878)
tensor(0.0830)
tensor(0.0788)
tensor(0.0858)
tensor(0.0742)
tensor(0.0850)
tensor(0.0808)
tensor(0.0841)
tensor(0.0903)
tensor(0.0788)
tensor(0.0833)
tensor(0.0871)


180it [00:01, 149.45it/s]

tensor(0.0803)
tensor(0.0753)
tensor(0.0891)
tensor(0.0787)
tensor(0.0718)
tensor(0.0806)
tensor(0.0821)
tensor(0.0762)
tensor(0.0817)
tensor(0.0858)
tensor(0.0893)
tensor(0.0786)
tensor(0.0832)
tensor(0.0876)
tensor(0.0826)
tensor(0.0737)
tensor(0.0890)
tensor(0.0951)
tensor(0.0837)
tensor(0.0826)
tensor(0.0809)
tensor(0.0789)
tensor(0.0814)
tensor(0.0789)
tensor(0.0800)
tensor(0.0835)
tensor(0.0779)
tensor(0.0772)
tensor(0.0771)
tensor(0.0871)
tensor(0.0789)
tensor(0.0841)
tensor(0.0759)
tensor(0.0791)
tensor(0.0848)
tensor(0.0738)
tensor(0.0853)
tensor(0.0817)
tensor(0.0788)
tensor(0.0809)
tensor(0.1001)
tensor(0.0853)
tensor(0.0811)
tensor(0.0758)
tensor(0.0815)
tensor(0.0747)


210it [00:01, 112.00it/s]

tensor(0.0722)
tensor(0.0799)
tensor(0.0828)
tensor(0.0870)
tensor(0.0831)
tensor(0.0815)
tensor(0.0845)
tensor(0.0836)
tensor(0.0857)
tensor(0.0757)
tensor(0.0839)
tensor(0.0756)
tensor(0.0904)
tensor(0.0883)
tensor(0.0835)
tensor(0.0891)
tensor(0.0864)
tensor(0.0719)
tensor(0.0872)
tensor(0.0809)
tensor(0.0773)
tensor(0.0752)
tensor(0.0875)
tensor(0.0813)
tensor(0.0756)
tensor(0.0875)
tensor(0.0893)
tensor(0.0817)
tensor(0.0745)
tensor(0.0860)


242it [00:01, 128.51it/s]

tensor(0.0899)
tensor(0.0808)
tensor(0.0805)
tensor(0.0805)
tensor(0.0744)
tensor(0.0810)
tensor(0.0731)
tensor(0.0784)
tensor(0.0761)
tensor(0.0788)
tensor(0.0843)
tensor(0.0816)
tensor(0.0768)
tensor(0.0810)
tensor(0.0790)
tensor(0.0799)
tensor(0.0808)
tensor(0.0861)
tensor(0.0717)
tensor(0.0775)
tensor(0.0776)
tensor(0.0818)
tensor(0.0859)
tensor(0.0831)
tensor(0.0814)
tensor(0.0864)
tensor(0.0704)
tensor(0.0808)
tensor(0.0731)
tensor(0.0821)


274it [00:02, 142.20it/s]

tensor(0.0835)
tensor(0.0912)
tensor(0.0825)
tensor(0.0883)
tensor(0.0773)
tensor(0.0893)
tensor(0.0822)
tensor(0.0804)
tensor(0.0817)
tensor(0.0843)
tensor(0.0852)
tensor(0.0809)
tensor(0.0879)
tensor(0.0816)
tensor(0.0818)
tensor(0.0857)
tensor(0.0853)
tensor(0.0864)
tensor(0.0889)
tensor(0.0786)
tensor(0.0925)
tensor(0.0838)
tensor(0.0792)
tensor(0.0741)
tensor(0.0795)
tensor(0.0782)
tensor(0.0836)
tensor(0.0877)
tensor(0.0874)
tensor(0.0908)
tensor(0.0759)
tensor(0.0841)
tensor(0.0880)
tensor(0.0868)


309it [00:02, 155.45it/s]

tensor(0.0916)
tensor(0.0800)
tensor(0.0772)
tensor(0.0757)
tensor(0.0753)
tensor(0.0897)
tensor(0.0703)
tensor(0.0900)
tensor(0.0782)
tensor(0.0820)
tensor(0.0792)
tensor(0.0779)
tensor(0.0807)
tensor(0.0780)
tensor(0.0795)
tensor(0.0841)
tensor(0.0817)
tensor(0.0870)
tensor(0.0836)
tensor(0.0834)
tensor(0.0821)
tensor(0.0800)
tensor(0.0823)
tensor(0.0822)
tensor(0.0836)
tensor(0.0842)
tensor(0.0775)
tensor(0.0749)
tensor(0.0871)
tensor(0.0849)
tensor(0.0796)
tensor(0.0928)
tensor(0.0851)
tensor(0.0829)
tensor(0.0835)


345it [00:02, 163.72it/s]

tensor(0.0782)
tensor(0.0812)
tensor(0.0780)
tensor(0.0804)
tensor(0.0835)
tensor(0.0852)
tensor(0.0783)
tensor(0.0814)
tensor(0.0851)
tensor(0.0778)
tensor(0.0793)
tensor(0.0751)
tensor(0.0844)
tensor(0.0883)
tensor(0.0777)
tensor(0.0865)
tensor(0.0783)
tensor(0.0790)
tensor(0.0735)
tensor(0.0833)
tensor(0.0801)
tensor(0.0770)
tensor(0.0796)
tensor(0.0774)
tensor(0.0778)
tensor(0.0825)
tensor(0.0857)
tensor(0.0819)
tensor(0.0802)
tensor(0.0907)
tensor(0.0867)
tensor(0.0842)
tensor(0.0808)
tensor(0.0766)
tensor(0.0785)


379it [00:02, 163.38it/s]

tensor(0.0845)
tensor(0.0861)
tensor(0.0830)
tensor(0.0862)
tensor(0.0825)
tensor(0.0868)
tensor(0.0832)
tensor(0.0864)
tensor(0.0819)
tensor(0.0898)
tensor(0.0819)
tensor(0.0801)
tensor(0.0760)
tensor(0.0838)
tensor(0.0911)
tensor(0.0807)
tensor(0.0680)
tensor(0.0824)
tensor(0.0902)
tensor(0.0766)
tensor(0.0855)
tensor(0.0847)
tensor(0.0808)
tensor(0.0847)
tensor(0.0870)
tensor(0.0789)
tensor(0.0777)
tensor(0.0843)
tensor(0.0777)
tensor(0.0873)
tensor(0.0802)
tensor(0.0804)
tensor(0.0771)


413it [00:02, 163.46it/s]

tensor(0.0738)
tensor(0.0791)
tensor(0.0791)
tensor(0.0850)
tensor(0.0867)
tensor(0.0834)
tensor(0.0891)
tensor(0.0844)
tensor(0.0790)
tensor(0.0837)
tensor(0.0780)
tensor(0.0802)
tensor(0.0764)
tensor(0.0911)
tensor(0.0868)
tensor(0.0921)
tensor(0.0818)
tensor(0.0807)
tensor(0.0856)
tensor(0.0784)
tensor(0.0753)
tensor(0.0898)
tensor(0.0833)
tensor(0.0720)
tensor(0.0885)
tensor(0.0839)
tensor(0.0744)
tensor(0.0806)
tensor(0.0898)
tensor(0.0713)
tensor(0.0849)
tensor(0.0889)
tensor(0.0928)


448it [00:03, 164.45it/s]

tensor(0.0825)
tensor(0.0899)
tensor(0.0797)
tensor(0.0778)
tensor(0.0809)
tensor(0.0789)
tensor(0.0748)
tensor(0.0817)
tensor(0.0817)
tensor(0.0817)
tensor(0.0831)
tensor(0.0781)
tensor(0.0819)
tensor(0.0892)
tensor(0.0857)
tensor(0.0896)
tensor(0.0762)
tensor(0.0854)
tensor(0.0874)
tensor(0.0844)
tensor(0.0832)
tensor(0.0761)
tensor(0.0829)
tensor(0.0743)
tensor(0.0887)
tensor(0.0735)
tensor(0.0683)
tensor(0.0840)
tensor(0.0838)
tensor(0.0844)
tensor(0.0790)
tensor(0.0810)
tensor(0.0778)
tensor(0.0819)
tensor(0.0884)
tensor(0.0783)


483it [00:03, 166.73it/s]

tensor(0.0860)
tensor(0.0773)
tensor(0.0897)
tensor(0.0915)
tensor(0.0800)
tensor(0.0830)
tensor(0.0821)
tensor(0.0823)
tensor(0.0899)
tensor(0.0790)
tensor(0.0839)
tensor(0.0845)
tensor(0.0848)
tensor(0.0802)
tensor(0.0827)
tensor(0.0832)
tensor(0.0784)
tensor(0.0822)
tensor(0.0817)
tensor(0.0832)
tensor(0.0733)
tensor(0.0802)
tensor(0.0780)
tensor(0.0908)
tensor(0.0861)
tensor(0.0851)
tensor(0.0788)
tensor(0.0835)
tensor(0.0907)
tensor(0.0793)
tensor(0.0852)
tensor(0.0862)
tensor(0.0835)


517it [00:03, 162.81it/s]

tensor(0.0844)
tensor(0.0916)
tensor(0.0861)
tensor(0.0793)
tensor(0.0828)
tensor(0.0788)
tensor(0.0835)
tensor(0.0824)
tensor(0.0932)
tensor(0.0804)
tensor(0.0784)
tensor(0.0831)
tensor(0.0762)
tensor(0.0774)
tensor(0.0803)
tensor(0.0881)
tensor(0.0719)
tensor(0.0803)
tensor(0.0758)
tensor(0.0851)
tensor(0.0842)
tensor(0.0876)
tensor(0.0800)
tensor(0.0840)
tensor(0.0800)
tensor(0.0943)
tensor(0.0798)
tensor(0.0892)
tensor(0.0909)
tensor(0.0779)
tensor(0.0849)
tensor(0.0846)
tensor(0.0849)
tensor(0.0807)


551it [00:03, 165.40it/s]

tensor(0.0935)
tensor(0.0816)
tensor(0.0888)
tensor(0.0820)
tensor(0.0758)
tensor(0.0833)
tensor(0.0778)
tensor(0.0845)
tensor(0.0848)
tensor(0.0838)
tensor(0.0834)
tensor(0.0752)
tensor(0.0825)
tensor(0.0809)
tensor(0.0800)
tensor(0.0850)
tensor(0.0915)
tensor(0.0878)
tensor(0.0768)
tensor(0.0818)
tensor(0.0850)
tensor(0.0812)
tensor(0.0746)
tensor(0.0919)
tensor(0.0779)
tensor(0.0821)
tensor(0.0853)
tensor(0.0831)
tensor(0.0877)
tensor(0.0821)
tensor(0.0823)
tensor(0.0861)
tensor(0.0863)
tensor(0.0759)
tensor(0.0811)


586it [00:03, 167.22it/s]

tensor(0.0814)
tensor(0.0876)
tensor(0.0813)
tensor(0.0825)
tensor(0.0775)
tensor(0.0789)
tensor(0.0822)
tensor(0.0762)
tensor(0.0871)
tensor(0.0912)
tensor(0.0881)
tensor(0.0766)
tensor(0.0792)
tensor(0.0868)
tensor(0.0830)
tensor(0.0881)
tensor(0.0830)
tensor(0.0834)
tensor(0.0860)
tensor(0.0851)
tensor(0.0793)
tensor(0.0821)
tensor(0.0847)
tensor(0.0807)
tensor(0.0816)
tensor(0.0802)
tensor(0.0755)
tensor(0.0761)
tensor(0.0761)
tensor(0.0843)
tensor(0.0795)
tensor(0.0678)
tensor(0.0833)


621it [00:04, 162.86it/s]

tensor(0.0854)
tensor(0.0838)
tensor(0.0691)
tensor(0.0819)
tensor(0.0781)
tensor(0.0871)
tensor(0.0743)
tensor(0.0741)
tensor(0.0816)
tensor(0.0816)
tensor(0.0850)
tensor(0.0756)
tensor(0.0753)
tensor(0.0752)
tensor(0.0773)
tensor(0.0830)
tensor(0.0892)
tensor(0.0818)
tensor(0.0775)
tensor(0.0857)
tensor(0.0905)
tensor(0.0835)
tensor(0.0816)
tensor(0.0882)
tensor(0.0825)
tensor(0.0819)
tensor(0.0758)
tensor(0.0837)
tensor(0.0866)
tensor(0.0810)
tensor(0.0887)
tensor(0.0850)
tensor(0.0844)
tensor(0.0690)
tensor(0.0878)


655it [00:04, 164.63it/s]

tensor(0.0810)
tensor(0.0875)
tensor(0.0767)
tensor(0.0759)
tensor(0.0785)
tensor(0.0837)
tensor(0.0753)
tensor(0.0835)
tensor(0.0800)
tensor(0.0772)
tensor(0.0962)
tensor(0.0802)
tensor(0.0865)
tensor(0.0859)
tensor(0.0875)
tensor(0.0829)
tensor(0.0780)
tensor(0.0820)
tensor(0.0828)
tensor(0.0820)
tensor(0.0893)
tensor(0.0874)
tensor(0.0797)
tensor(0.0837)
tensor(0.0837)
tensor(0.0847)
tensor(0.0783)
tensor(0.0801)
tensor(0.0814)
tensor(0.0768)
tensor(0.0805)
tensor(0.0814)
tensor(0.0833)
tensor(0.0728)


689it [00:04, 164.86it/s]

tensor(0.0812)
tensor(0.0791)
tensor(0.0801)
tensor(0.0846)
tensor(0.0742)
tensor(0.0845)
tensor(0.0837)
tensor(0.0768)
tensor(0.0828)
tensor(0.0818)
tensor(0.0817)
tensor(0.0728)
tensor(0.0752)
tensor(0.0884)
tensor(0.0790)
tensor(0.0850)
tensor(0.0712)
tensor(0.0770)
tensor(0.0872)
tensor(0.0764)
tensor(0.0796)
tensor(0.0735)
tensor(0.0887)
tensor(0.0774)
tensor(0.0873)
tensor(0.0770)
tensor(0.0839)
tensor(0.0895)
tensor(0.0831)
tensor(0.0713)
tensor(0.0817)
tensor(0.0850)
tensor(0.0781)


723it [00:04, 159.75it/s]

tensor(0.0830)
tensor(0.0778)
tensor(0.0792)
tensor(0.0831)
tensor(0.0742)
tensor(0.0770)
tensor(0.0755)
tensor(0.0804)
tensor(0.0898)
tensor(0.0834)
tensor(0.0747)
tensor(0.0808)
tensor(0.0928)
tensor(0.0807)
tensor(0.0819)
tensor(0.0762)
tensor(0.0835)
tensor(0.0878)
tensor(0.0795)
tensor(0.0715)
tensor(0.0814)
tensor(0.0890)
tensor(0.0831)
tensor(0.0813)
tensor(0.0879)
tensor(0.0871)
tensor(0.0917)
tensor(0.0814)
tensor(0.0840)
tensor(0.0796)
tensor(0.0776)
tensor(0.0770)


756it [00:04, 160.00it/s]

tensor(0.0800)
tensor(0.0868)
tensor(0.0800)
tensor(0.0808)
tensor(0.0787)
tensor(0.0835)
tensor(0.0813)
tensor(0.0864)
tensor(0.0832)
tensor(0.0749)
tensor(0.0872)
tensor(0.0831)
tensor(0.0909)
tensor(0.0894)
tensor(0.0789)
tensor(0.0784)
tensor(0.0816)
tensor(0.0857)
tensor(0.0799)
tensor(0.0793)
tensor(0.0836)
tensor(0.0716)
tensor(0.0805)
tensor(0.0902)
tensor(0.0711)
tensor(0.0838)
tensor(0.0789)
tensor(0.0775)
tensor(0.0774)
tensor(0.0804)
tensor(0.0763)
tensor(0.0765)


790it [00:05, 162.03it/s]

tensor(0.0877)
tensor(0.0793)
tensor(0.0921)
tensor(0.0828)
tensor(0.0872)
tensor(0.0828)
tensor(0.0819)
tensor(0.0848)
tensor(0.0736)
tensor(0.0801)
tensor(0.0873)
tensor(0.0828)
tensor(0.0790)
tensor(0.0852)
tensor(0.0842)
tensor(0.0926)
tensor(0.0906)
tensor(0.0755)
tensor(0.0791)
tensor(0.0786)
tensor(0.0888)
tensor(0.0783)
tensor(0.0780)
tensor(0.0813)
tensor(0.0885)
tensor(0.0815)
tensor(0.0745)
tensor(0.0830)
tensor(0.0837)
tensor(0.0812)
tensor(0.0855)
tensor(0.0901)
tensor(0.0881)
tensor(0.0873)


824it [00:05, 162.66it/s]

tensor(0.0905)
tensor(0.0843)
tensor(0.0822)
tensor(0.0826)
tensor(0.0895)
tensor(0.0890)
tensor(0.0835)
tensor(0.0884)
tensor(0.0777)
tensor(0.0838)
tensor(0.0815)
tensor(0.0714)
tensor(0.0702)
tensor(0.0782)
tensor(0.0849)
tensor(0.0826)
tensor(0.0751)
tensor(0.0785)
tensor(0.0822)
tensor(0.0875)
tensor(0.0848)
tensor(0.0789)
tensor(0.0782)
tensor(0.0850)
tensor(0.0825)
tensor(0.0782)
tensor(0.0837)
tensor(0.0868)
tensor(0.0817)
tensor(0.0884)
tensor(0.0827)
tensor(0.0823)
tensor(0.0862)
tensor(0.0789)


859it [00:05, 166.59it/s]

tensor(0.0862)
tensor(0.0758)
tensor(0.0834)
tensor(0.0772)
tensor(0.0783)
tensor(0.0821)
tensor(0.0800)
tensor(0.0797)
tensor(0.0833)
tensor(0.0873)
tensor(0.0766)
tensor(0.0752)
tensor(0.0907)
tensor(0.0856)
tensor(0.0833)
tensor(0.0780)
tensor(0.0901)
tensor(0.0760)
tensor(0.0767)
tensor(0.0853)
tensor(0.0874)
tensor(0.0800)
tensor(0.0774)
tensor(0.0708)
tensor(0.0801)
tensor(0.0773)
tensor(0.0858)
tensor(0.0839)
tensor(0.0838)
tensor(0.0751)
tensor(0.0774)
tensor(0.0892)
tensor(0.0814)
tensor(0.0856)
tensor(0.0782)


894it [00:05, 168.33it/s]

tensor(0.0805)
tensor(0.0847)
tensor(0.0811)
tensor(0.0880)
tensor(0.0808)
tensor(0.0823)
tensor(0.0791)
tensor(0.0769)
tensor(0.0762)
tensor(0.0812)
tensor(0.0824)
tensor(0.0822)
tensor(0.0864)
tensor(0.0798)
tensor(0.0892)
tensor(0.0861)
tensor(0.0809)
tensor(0.0910)
tensor(0.0792)
tensor(0.0799)
tensor(0.0858)
tensor(0.0862)
tensor(0.0837)
tensor(0.0869)
tensor(0.0859)
tensor(0.0850)
tensor(0.0809)
tensor(0.0806)
tensor(0.0706)
tensor(0.0822)
tensor(0.0850)
tensor(0.0826)
tensor(0.0833)
tensor(0.0835)
tensor(0.0858)


928it [00:06, 164.63it/s]

tensor(0.0895)
tensor(0.0765)
tensor(0.0924)
tensor(0.0881)
tensor(0.0919)
tensor(0.0849)
tensor(0.0796)
tensor(0.0855)
tensor(0.0809)
tensor(0.0813)
tensor(0.0815)
tensor(0.0827)
tensor(0.0908)
tensor(0.0846)
tensor(0.0852)
tensor(0.0745)
tensor(0.0902)
tensor(0.0792)
tensor(0.0852)
tensor(0.0761)
tensor(0.0791)
tensor(0.0822)
tensor(0.0842)
tensor(0.0878)
tensor(0.0829)
tensor(0.0804)
tensor(0.0747)
tensor(0.0724)
tensor(0.0729)
tensor(0.0960)
tensor(0.0771)
tensor(0.0740)
tensor(0.0743)


961it [00:06, 155.49it/s]

tensor(0.0878)
tensor(0.0810)
tensor(0.0819)
tensor(0.0757)
tensor(0.0851)
tensor(0.0927)
tensor(0.0907)
tensor(0.0874)
tensor(0.0808)
tensor(0.0832)
tensor(0.0844)
tensor(0.0821)
tensor(0.0828)
tensor(0.0799)
tensor(0.0904)
tensor(0.0830)
tensor(0.0875)
tensor(0.0819)
tensor(0.0891)
tensor(0.0866)
tensor(0.0856)
tensor(0.0812)
tensor(0.0926)
tensor(0.0857)
tensor(0.0875)
tensor(0.0776)
tensor(0.0802)
tensor(0.0796)
tensor(0.0888)
tensor(0.0819)


993it [00:06, 155.66it/s]

tensor(0.0853)
tensor(0.0783)
tensor(0.0872)
tensor(0.0861)
tensor(0.0822)
tensor(0.0856)
tensor(0.0842)
tensor(0.1022)
tensor(0.0792)
tensor(0.0792)
tensor(0.0798)
tensor(0.0813)
tensor(0.0826)
tensor(0.0748)
tensor(0.0923)
tensor(0.0908)
tensor(0.0784)
tensor(0.0797)
tensor(0.0793)
tensor(0.0821)
tensor(0.0809)
tensor(0.0777)
tensor(0.0843)
tensor(0.0843)
tensor(0.0754)
tensor(0.0805)
tensor(0.0845)
tensor(0.0893)
tensor(0.0800)
tensor(0.0808)
tensor(0.0828)
tensor(0.0815)


1025it [00:06, 154.17it/s]

tensor(0.0827)
tensor(0.0796)
tensor(0.0840)
tensor(0.0919)
tensor(0.0824)
tensor(0.0808)
tensor(0.0792)
tensor(0.0877)
tensor(0.0859)
tensor(0.0723)
tensor(0.0866)
tensor(0.0870)
tensor(0.0881)
tensor(0.0878)
tensor(0.0819)
tensor(0.0824)
tensor(0.0819)
tensor(0.0826)
tensor(0.0850)
tensor(0.0767)
tensor(0.0883)
tensor(0.0862)
tensor(0.0800)
tensor(0.0805)
tensor(0.0807)
tensor(0.0789)
tensor(0.0830)
tensor(0.0802)
tensor(0.0829)
tensor(0.0821)
tensor(0.0905)


1057it [00:06, 155.96it/s]

tensor(0.0830)
tensor(0.0886)
tensor(0.0871)
tensor(0.0813)
tensor(0.0678)
tensor(0.0859)
tensor(0.0807)
tensor(0.0837)
tensor(0.0739)
tensor(0.0875)
tensor(0.0834)
tensor(0.0917)
tensor(0.0744)
tensor(0.0887)
tensor(0.0807)
tensor(0.0800)
tensor(0.0814)
tensor(0.0832)
tensor(0.0918)
tensor(0.0824)
tensor(0.0831)
tensor(0.0808)
tensor(0.0869)
tensor(0.0861)
tensor(0.0827)
tensor(0.0839)
tensor(0.0829)
tensor(0.0818)
tensor(0.0813)
tensor(0.0752)
tensor(0.0831)
tensor(0.0866)


1073it [00:06, 154.09it/s]

tensor(0.0807)
tensor(0.0764)
tensor(0.0901)
tensor(0.0821)
tensor(0.0823)
tensor(0.0828)
tensor(0.0815)
tensor(0.0867)
tensor(0.0804)
tensor(0.0950)
tensor(0.0880)
tensor(0.0852)
tensor(0.0821)
tensor(0.0807)
tensor(0.0824)
tensor(0.0750)
tensor(0.0854)
tensor(0.0823)
tensor(0.0786)
tensor(0.0924)
tensor(0.0818)
tensor(0.0756)
tensor(0.0832)
tensor(0.0886)
tensor(0.0870)
tensor(0.0836)
tensor(0.0823)
tensor(0.0901)
tensor(0.0894)
tensor(0.0807)
tensor(0.0820)


1105it [00:07, 155.93it/s]

tensor(0.0876)
tensor(0.0868)
tensor(0.0775)
tensor(0.0833)
tensor(0.0769)
tensor(0.0840)
tensor(0.0855)
tensor(0.0820)
tensor(0.0784)
tensor(0.0776)
tensor(0.0812)
tensor(0.0814)
tensor(0.0837)
tensor(0.0879)
tensor(0.0827)
tensor(0.0808)
tensor(0.0863)
tensor(0.0868)
tensor(0.0775)
tensor(0.0816)
tensor(0.0790)
tensor(0.0767)
tensor(0.0863)
tensor(0.0861)
tensor(0.0817)
tensor(0.0721)
tensor(0.0856)
tensor(0.0877)
tensor(0.0796)
tensor(0.0796)
tensor(0.0766)
tensor(0.0791)


1137it [00:07, 156.01it/s]

tensor(0.0801)
tensor(0.0808)
tensor(0.0947)
tensor(0.0863)
tensor(0.0819)
tensor(0.0836)
tensor(0.0744)
tensor(0.0663)
tensor(0.0856)
tensor(0.0860)
tensor(0.0736)
tensor(0.0827)
tensor(0.0747)
tensor(0.0827)
tensor(0.0879)
tensor(0.0846)
tensor(0.0823)
tensor(0.0871)
tensor(0.0878)
tensor(0.0810)
tensor(0.0852)
tensor(0.0898)
tensor(0.0864)
tensor(0.0724)
tensor(0.0871)
tensor(0.0784)
tensor(0.0900)
tensor(0.0861)
tensor(0.0790)
tensor(0.0820)
tensor(0.0849)
tensor(0.0844)


1169it [00:07, 156.97it/s]

tensor(0.0840)
tensor(0.0745)
tensor(0.0873)
tensor(0.0877)
tensor(0.0836)
tensor(0.0796)
tensor(0.0828)
tensor(0.0804)
tensor(0.0847)
tensor(0.0854)
tensor(0.0821)
tensor(0.0799)
tensor(0.0792)
tensor(0.0761)
tensor(0.0910)
tensor(0.0837)
tensor(0.0884)
tensor(0.0843)
tensor(0.0821)
tensor(0.0731)
tensor(0.0803)
tensor(0.0901)
tensor(0.0780)
tensor(0.0875)
tensor(0.0825)
tensor(0.0901)
tensor(0.0867)
tensor(0.0852)
tensor(0.0772)
tensor(0.0770)
tensor(0.0858)
tensor(0.0850)


1201it [00:07, 145.96it/s]

tensor(0.0734)
tensor(0.0749)
tensor(0.0865)
tensor(0.0895)
tensor(0.0815)
tensor(0.0788)
tensor(0.0706)
tensor(0.0774)
tensor(0.0840)
tensor(0.0858)
tensor(0.0896)
tensor(0.0777)
tensor(0.0810)
tensor(0.0807)
tensor(0.0782)
tensor(0.0850)
tensor(0.0811)
tensor(0.0879)
tensor(0.0839)
tensor(0.0862)
tensor(0.0805)
tensor(0.0852)
tensor(0.0833)
tensor(0.0787)
tensor(0.0891)
tensor(0.0774)
tensor(0.0811)
tensor(0.0836)


1234it [00:08, 151.49it/s]

tensor(0.0763)
tensor(0.0854)
tensor(0.0820)
tensor(0.0856)
tensor(0.0790)
tensor(0.0737)
tensor(0.0826)
tensor(0.0758)
tensor(0.0841)
tensor(0.0861)
tensor(0.0883)
tensor(0.0956)
tensor(0.0790)
tensor(0.0834)
tensor(0.0896)
tensor(0.0887)
tensor(0.0854)
tensor(0.0810)
tensor(0.0865)
tensor(0.0766)
tensor(0.0853)
tensor(0.0949)
tensor(0.0738)
tensor(0.0841)
tensor(0.0690)
tensor(0.0741)
tensor(0.0856)
tensor(0.0731)
tensor(0.0916)
tensor(0.0836)
tensor(0.0885)
tensor(0.0874)


1266it [00:08, 149.97it/s]

tensor(0.0753)
tensor(0.0810)
tensor(0.0840)
tensor(0.0830)
tensor(0.0837)
tensor(0.0792)
tensor(0.0808)
tensor(0.0758)
tensor(0.0814)
tensor(0.0743)
tensor(0.0833)
tensor(0.0752)
tensor(0.0807)
tensor(0.0755)
tensor(0.0765)
tensor(0.0799)
tensor(0.0842)
tensor(0.0903)
tensor(0.0872)
tensor(0.0851)
tensor(0.0878)
tensor(0.0771)
tensor(0.0842)
tensor(0.0857)
tensor(0.0783)
tensor(0.0855)
tensor(0.0845)
tensor(0.0755)
tensor(0.0949)
tensor(0.0748)


1298it [00:08, 145.80it/s]

tensor(0.0759)
tensor(0.0793)
tensor(0.0798)
tensor(0.0837)
tensor(0.0820)
tensor(0.0919)
tensor(0.0895)
tensor(0.0942)
tensor(0.0850)
tensor(0.0843)
tensor(0.0814)
tensor(0.0853)
tensor(0.0844)
tensor(0.0897)
tensor(0.0789)
tensor(0.0813)
tensor(0.0811)
tensor(0.0818)
tensor(0.0822)
tensor(0.0823)
tensor(0.0759)
tensor(0.0836)
tensor(0.0903)
tensor(0.0811)
tensor(0.0846)
tensor(0.0790)
tensor(0.0860)
tensor(0.0774)
tensor(0.0792)


1328it [00:08, 146.74it/s]

tensor(0.0875)
tensor(0.0849)
tensor(0.0798)
tensor(0.0811)
tensor(0.0808)
tensor(0.0851)
tensor(0.0775)
tensor(0.0846)
tensor(0.0853)
tensor(0.0791)
tensor(0.0762)
tensor(0.0783)
tensor(0.0773)
tensor(0.0761)
tensor(0.0849)
tensor(0.0767)
tensor(0.0854)
tensor(0.0875)
tensor(0.0850)
tensor(0.0732)
tensor(0.0797)
tensor(0.0932)
tensor(0.0757)
tensor(0.0935)
tensor(0.0837)
tensor(0.0864)
tensor(0.0834)
tensor(0.0859)
tensor(0.0810)
tensor(0.0887)


1359it [00:08, 149.07it/s]

tensor(0.0761)
tensor(0.0777)
tensor(0.0890)
tensor(0.0794)
tensor(0.0781)
tensor(0.0945)
tensor(0.0805)
tensor(0.0792)
tensor(0.0823)
tensor(0.0799)
tensor(0.0870)
tensor(0.0782)
tensor(0.0724)
tensor(0.0796)
tensor(0.0869)
tensor(0.0842)
tensor(0.0803)
tensor(0.0814)
tensor(0.0913)
tensor(0.0806)
tensor(0.0830)
tensor(0.0811)
tensor(0.0753)
tensor(0.0868)
tensor(0.0785)
tensor(0.0835)
tensor(0.0809)
tensor(0.0893)
tensor(0.0868)
tensor(0.0850)
tensor(0.0830)


1391it [00:09, 150.23it/s]

tensor(0.0885)
tensor(0.0721)
tensor(0.0954)
tensor(0.0876)
tensor(0.0805)
tensor(0.0823)
tensor(0.0842)
tensor(0.0914)
tensor(0.0786)
tensor(0.0793)
tensor(0.0886)
tensor(0.0866)
tensor(0.0778)
tensor(0.0974)
tensor(0.0783)
tensor(0.0822)
tensor(0.0798)
tensor(0.0864)
tensor(0.0872)
tensor(0.0759)
tensor(0.0820)
tensor(0.0778)
tensor(0.0795)
tensor(0.0910)
tensor(0.0764)
tensor(0.0835)
tensor(0.0934)
tensor(0.0863)
tensor(0.0735)
tensor(0.0766)
tensor(0.0792)


1423it [00:09, 150.40it/s]

tensor(0.0875)
tensor(0.0834)
tensor(0.0837)
tensor(0.0897)
tensor(0.0808)
tensor(0.0850)
tensor(0.0795)
tensor(0.0930)
tensor(0.0799)
tensor(0.0825)
tensor(0.0819)
tensor(0.0838)
tensor(0.0825)
tensor(0.0863)
tensor(0.0805)
tensor(0.0845)
tensor(0.0810)
tensor(0.0833)
tensor(0.0736)
tensor(0.0861)
tensor(0.0816)
tensor(0.0835)
tensor(0.0718)
tensor(0.0835)
tensor(0.0837)
tensor(0.0812)
tensor(0.0815)
tensor(0.0779)
tensor(0.0785)
tensor(0.0840)
tensor(0.0784)


1455it [00:09, 152.73it/s]

tensor(0.0850)
tensor(0.0840)
tensor(0.0896)
tensor(0.0865)
tensor(0.0863)
tensor(0.0742)
tensor(0.0768)
tensor(0.0844)
tensor(0.0898)
tensor(0.0798)
tensor(0.0851)
tensor(0.0762)
tensor(0.0911)
tensor(0.0788)
tensor(0.0930)
tensor(0.0791)
tensor(0.0945)
tensor(0.0790)
tensor(0.0781)
tensor(0.0839)
tensor(0.0809)
tensor(0.0689)
tensor(0.0873)
tensor(0.0763)
tensor(0.0790)
tensor(0.0775)
tensor(0.0830)
tensor(0.0824)
tensor(0.0869)
tensor(0.0868)
tensor(0.0787)
tensor(0.0822)


1487it [00:09, 151.13it/s]

tensor(0.0792)
tensor(0.0779)
tensor(0.0780)
tensor(0.0747)
tensor(0.0781)
tensor(0.0814)
tensor(0.0880)
tensor(0.0820)
tensor(0.0836)
tensor(0.0889)
tensor(0.0787)
tensor(0.0846)
tensor(0.0795)
tensor(0.0821)
tensor(0.0944)
tensor(0.0912)
tensor(0.0773)
tensor(0.0875)
tensor(0.0799)
tensor(0.0787)
tensor(0.0812)
tensor(0.0889)
tensor(0.0856)
tensor(0.0741)
tensor(0.0771)
tensor(0.0803)
tensor(0.0885)
tensor(0.0821)
tensor(0.0806)
tensor(0.0849)


1519it [00:09, 153.75it/s]

tensor(0.0824)
tensor(0.0836)
tensor(0.0810)
tensor(0.0822)
tensor(0.0833)
tensor(0.0821)
tensor(0.0821)
tensor(0.0823)
tensor(0.0874)
tensor(0.0839)
tensor(0.0787)
tensor(0.0856)
tensor(0.0901)
tensor(0.0795)
tensor(0.0889)
tensor(0.0859)
tensor(0.0820)
tensor(0.0837)
tensor(0.0796)
tensor(0.0937)
tensor(0.0860)
tensor(0.0796)
tensor(0.0782)
tensor(0.0816)
tensor(0.0772)
tensor(0.0835)
tensor(0.0775)
tensor(0.0864)
tensor(0.0811)
tensor(0.0810)
tensor(0.0767)
tensor(0.0789)
tensor(0.0867)


1551it [00:10, 155.85it/s]

tensor(0.0874)
tensor(0.0771)
tensor(0.0754)
tensor(0.0892)
tensor(0.0811)
tensor(0.0812)
tensor(0.0778)
tensor(0.0897)
tensor(0.0843)
tensor(0.0827)
tensor(0.0810)
tensor(0.0820)
tensor(0.0767)
tensor(0.0822)
tensor(0.0850)
tensor(0.0826)
tensor(0.0787)
tensor(0.0854)
tensor(0.0800)
tensor(0.0793)
tensor(0.0819)
tensor(0.0796)
tensor(0.0842)
tensor(0.0796)
tensor(0.0812)
tensor(0.0755)
tensor(0.0849)
tensor(0.0728)
tensor(0.0803)
tensor(0.0804)
tensor(0.0744)
tensor(0.0853)


1568it [00:10, 153.16it/s]


tensor(0.0838)
tensor(0.0790)
tensor(0.0919)
tensor(0.0806)
tensor(0.0894)
tensor(0.0734)
tensor(0.0857)
tensor(0.0942)
tensor(0.0857)
tensor(0.0767)
tensor(0.0886)
tensor(0.0824)
tensor(0.0813)
tensor(0.0773)
tensor(0.0824)


KeyboardInterrupt: 