In [1]:
import torch
import importlib
import sys
import random
# The `module` package sits one directory above this notebook, so add the
# parent directory to the import search path before importing from it.
sys.path.append("..")
from module import dataset
# Reload so edits made to dataset.py are picked up without restarting the
# kernel. As the last expression of the cell, the reloaded module is displayed.
importlib.reload(dataset)

<module 'module.dataset' from '/home/argon/openmm/cgschnet/cgschnet/cgschnet/scripts/tests/../module/dataset.py'>

In [2]:
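# Optional debugging aid (disabled): uncomment the next line to make tensor
# reprs show only their shape instead of their full contents.
# torch.Tensor.__repr__ = lambda x : f"tensor(shape={list(x.shape)})"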

In [3]:
def make_mock_items(n_items, gen_boxes=False, lengths=None):
    """Build a list of synthetic dataset items for exercising the collate function.

    Each item is a dict with the keys "pos", "z", "force" and "box"; every value
    is a list holding one tensor per molecule length.

    Args:
        n_items: Number of items to generate.
        gen_boxes: When true, every molecule gets a zero-filled (1, 3, 3) box
            tensor; otherwise it gets an empty tensor.
        lengths: Number of atoms per molecule. Any falsy value (None or an empty
            list) selects the built-in list of eight lengths.

    Returns:
        A list of ``n_items`` dicts, each built from freshly allocated tensors.
    """
    sizes = lengths if lengths else [105, 101, 66, 57, 85, 67, 91, 61]

    def _one_item():
        # Box tensors use an integer dtype in both variants.
        if gen_boxes:
            box_list = [torch.zeros((1, 3, 3), dtype=torch.long) for _ in sizes]
        else:
            box_list = [torch.tensor([], dtype=torch.long) for _ in sizes]
        return {
            "pos": [torch.zeros((n, 3), dtype=torch.float) for n in sizes],
            "z": [torch.zeros((n), dtype=torch.long) for n in sizes],
            # Forces are filled with the molecule length, so molecules of
            # different sizes contribute different values.
            "force": [n * torch.ones((n, 3), dtype=torch.float) for n in sizes],
            "box": box_list,
        }

    return [_one_item() for _ in range(n_items)]

In [4]:
# Collator under test, constructed with a limit of 500. Judging by the splits
# asserted in the cells below, the limit appears to bound
# (atoms per item) x (items in the batch) for each sub-batch
# -- TODO confirm against dataset.ProteinBatchCollate.
collate_fn = dataset.ProteinBatchCollate(500)

In [5]:
# One item carrying the default eight molecules, generated without boxes.
# Expected result: two sub-batches.
#   * sub-batch 0 holds the first six molecules: 105+101+66+57+85+67 = 481,
#     which is the atom dimension asserted for pos / z / force;
#   * sub-batch 1 holds the remaining two molecules (91, 61).
# With no boxes generated, "box" is expected to have shape (1, 0).
batched = collate_fn(make_mock_items(1))
assert len(batched) == 2
assert batched[0]["pos"].shape == (1, 481, 3)
assert batched[0]["z"].shape == (1, 481)
assert batched[0]["force"].shape == (1, 481, 3)
assert batched[0]["box"].shape == (1, 0)
# Every molecule must land in exactly one sub-batch.
assert sum(len(i["lengths"]) for i in batched) == 8
assert batched[0]["lengths"] == [105, 101, 66, 57, 85, 67]
assert batched[1]["lengths"] == [91, 61]

In [6]:
# Two items, no boxes. The same eight molecules are now expected to be split
# into three sub-batches instead of two, i.e. fewer molecules fit per sub-batch
# when the batch holds more items.
#   * sub-batch 0: 105+101 = 206 atoms per item (the asserted atom dimension)
#   * sub-batch 1: 66+57+85
#   * sub-batch 2: 67+91+61
# The leading dimension of every tensor equals the number of items (2).
batched = collate_fn(make_mock_items(2))
assert len(batched) == 3
assert batched[0]["pos"].shape == (2, 206, 3)
assert batched[0]["z"].shape == (2, 206)
assert batched[0]["force"].shape == (2, 206, 3)
assert batched[0]["box"].shape == (2, 0)
# Every molecule must land in exactly one sub-batch.
assert sum(len(i["lengths"]) for i in batched) == 8
assert batched[0]["lengths"] == [105, 101]
assert batched[1]["lengths"] == [66, 57, 85]
assert batched[2]["lengths"] == [67, 91, 61]

In [7]:
# With eight items the 500-limit collator is expected to refuse the batch by
# raising AssertionError (recorded message: "Molecule 0 is too large (840x8)>500").
# NOTE(review): 840 equals 105 * 8 (first molecule length x items), so the "x8"
# in that message may be double-counting -- confirm in dataset.py.
assertion_ok = False
try:
    collate_fn(make_mock_items(8))
except AssertionError as e:
    # Expected path: record success and show the collator's own message.
    assertion_ok = True
    print("OK:", str(e))
# Fails only when the call above returned without raising AssertionError.
assert assertion_ok, "Failed to detect too large a batch"

OK: Molecule 0 is too large (840x8)>500


In [8]:
# Same scenario as the single-item case, but with boxes generated. The split
# into sub-batches is expected to be unchanged (481 = 105+101+66+57+85+67);
# only the "box" tensor differs: one 3x3 box per molecule in the sub-batch,
# for each item.
batched = collate_fn(make_mock_items(1, True))
assert len(batched) == 2
assert batched[0]["pos"].shape == (1, 481, 3)
assert batched[0]["z"].shape == (1, 481)
assert batched[0]["force"].shape == (1, 481, 3)
assert batched[0]["box"].shape == (1, len(batched[0]["lengths"]), 3, 3)
# Every molecule must land in exactly one sub-batch.
assert sum(len(i["lengths"]) for i in batched) == 8
assert batched[0]["lengths"] == [105, 101, 66, 57, 85, 67]
assert batched[1]["lengths"] == [91, 61]

In [9]:
# Two items with boxes generated. The three-way split is expected to match the
# box-free two-item case (206 = 105+101 atoms per item in the first sub-batch);
# "box" carries one 3x3 box per molecule in the sub-batch, for each item.
batched = collate_fn(make_mock_items(2, True))
assert len(batched) == 3
assert batched[0]["pos"].shape == (2, 206, 3)
assert batched[0]["z"].shape == (2, 206)
assert batched[0]["force"].shape == (2, 206, 3)
assert batched[0]["box"].shape == (2, len(batched[0]["lengths"]), 3, 3)
# Every molecule must land in exactly one sub-batch.
assert sum(len(i["lengths"]) for i in batched) == 8
assert batched[0]["lengths"] == [105, 101]
assert batched[1]["lengths"] == [66, 57, 85]
assert batched[2]["lengths"] == [67, 91, 61]

In [10]:
# Collator constructed with None in place of a numeric limit. The cells below
# expect it to keep all molecules in a single sub-batch, so None presumably
# means "no limit" -- confirm against dataset.ProteinBatchCollate.
collate_fn_inf = dataset.ProteinBatchCollate(None)

In [11]:
# Two items through the collator built with None: everything is expected in one
# sub-batch whose atom dimension is the sum of all eight lengths,
# 105+101+66+57+85+67+91+61 = 633.
batched = collate_fn_inf(make_mock_items(2, True))
assert len(batched) == 1
assert batched[0]["pos"].shape == (2, 633, 3)
assert batched[0]["z"].shape == (2, 633)
assert batched[0]["force"].shape == (2, 633, 3)
assert batched[0]["box"].shape == (2, len(batched[0]["lengths"]), 3, 3)
assert sum(len(i["lengths"]) for i in batched) == 8
# Molecule order is expected to be preserved.
assert batched[0]["lengths"] == [105, 101, 66, 57, 85, 67, 91, 61]

In [12]:
# Eight items -- the batch size the 500-limit collator is expected to reject --
# through the collator built with None: a single sub-batch of all eight
# molecules (633 atoms per item) with a leading dimension of 8.
batched = collate_fn_inf(make_mock_items(8, True))
assert len(batched) == 1
assert batched[0]["pos"].shape == (8, 633, 3)
assert batched[0]["z"].shape == (8, 633)
assert batched[0]["force"].shape == (8, 633, 3)
assert batched[0]["box"].shape == (8, len(batched[0]["lengths"]), 3, 3)
assert sum(len(i["lengths"]) for i in batched) == 8
# Molecule order is expected to be preserved.
assert batched[0]["lengths"] == [105, 101, 66, 57, 85, 67, 91, 61]

Test sub-batch grouping logic

In [23]:
import numpy as np
from torch import nn

In [70]:
# Checklist for this section ("xTODO" presumably marks an item as done -- confirm):
# xTODO: Make shuffle fixed
# xTODO: Test Kevin's set of protein sizes
# xTODO: Test large and small batch sizes
# xTODO: Test at least one more collate_fn size

# Fixed seed so the shuffled orderings below are the same on every run. The
# permutations are drawn in sequence from one generator, so reordering these
# statements would change the fixtures.
rng = np.random.default_rng(424242)
# Set A: the default eight molecule lengths plus two shuffled orderings.
lengths_a0 = [105,101,66,57,85,67,91,61]
lengths_a1 = rng.permutation(lengths_a0).tolist()
lengths_a2 = rng.permutation(lengths_a0).tolist()
# Set B: five larger molecules plus two shuffled orderings.
lengths_b0 = [164, 105, 125, 131, 174]
lengths_b1 = rng.permutation(lengths_b0).tolist()
lengths_b2 = rng.permutation(lengths_b0).tolist()
# Set C: two 500-atom molecules mixed with small ones; used unshuffled.
lengths_c0 = [500, 50, 25, 500, 25]

# Display sets A and B for inspection (lengths_c0 is not part of this display).
lengths_a0, lengths_a1, lengths_a2, lengths_b0, lengths_b1, lengths_b2

([105, 101, 66, 57, 85, 67, 91, 61],
 [101, 61, 66, 85, 91, 57, 105, 67],
 [66, 61, 105, 101, 67, 85, 57, 91],
 [164, 105, 125, 131, 174],
 [131, 164, 174, 105, 125],
 [125, 174, 164, 105, 131])

In [71]:
def test_batch(batch):
    criterion = nn.MSELoss(reduction="mean")
    total_batch_size = sum([i["force"].numel() for i in batch])
    # print(total_batch_size)
    batch_loss = 0
    for sub_batch in batch:
        sub_batch_size = sub_batch["force"].numel()
        # print(sub_batch_size)
        loss = criterion(sub_batch["force"], torch.zeros_like(sub_batch["force"])) * (sub_batch_size / total_batch_size)
        batch_loss += loss.item() # Standin for the "backwards" loss
    return batch_loss

In [89]:
# Compare the size-weighted loss obtained from limited (sub-batched) collation
# with the loss of the same items collated by the None-limit collator. If the
# grouping and weighting are correct the two agree up to float round-off.
collate_fn_1000 = dataset.ProteinBatchCollate(1000)
collate_fn_2000 = dataset.ProteinBatchCollate(2000)
collate_fn_10000 = dataset.ProteinBatchCollate(10000)
collate_fn_inf = dataset.ProteinBatchCollate(None)

# Largest tolerated |% err|. The recorded runs stay below 2e-5 %, so this leaves
# ample head-room for round-off while still catching a real weighting bug.
MAX_PCT_ERR = 1e-3


def _report_pct_err(n_items, lengths, collate):
    """Print and check the relative loss error of `collate` versus collate_fn_inf.

    Args:
        n_items: Number of mock items in the batch.
        lengths: Per-molecule atom counts passed to make_mock_items.
        collate: The limited collator to compare against collate_fn_inf.

    Returns:
        The relative error in percent.

    Raises:
        AssertionError: If the absolute relative error reaches MAX_PCT_ERR.
    """
    batch = make_mock_items(n_items, False, lengths=lengths)
    # To inspect the grouping: print([sb["lengths"] for sb in collate(batch)])
    inf_loss = test_batch(collate_fn_inf(batch))
    batched_loss = test_batch(collate(batch))
    pct_err = 100*(batched_loss - inf_loss)/inf_loss
    print("% err =", pct_err)
    assert abs(pct_err) < MAX_PCT_ERR, (
        f"sub-batched loss deviates by {pct_err}% for {n_items} items, lengths={lengths}"
    )
    return pct_err


# Each batch size is paired with a limit that forces the items to be split.
shuffled_length_sets = [lengths_a0, lengths_a1, lengths_a2, lengths_b0, lengths_b1, lengths_b2]
for label, n_items, collate in [
    ("Batch 4", 4, collate_fn_1000),
    ("Batch 10", 10, collate_fn_2000),
    ("Batch 50", 50, collate_fn_10000),
]:
    print(label)
    for lengths in shuffled_length_sets:
        _report_pct_err(n_items, lengths, collate)

# Large molecules (up to 500 atoms) at several batch sizes.
print("Batch Large...")
for n_items in [10,20,25]:
    for lengths in [lengths_c0]:
        print("  Batch",n_items)
        _report_pct_err(n_items, lengths, collate_fn_10000)

Batch 4
% err = -5.087987586938444e-06
% err = -3.3919917246256295e-06
% err = -6.783982989026975e-06
% err = -5.674174037721364e-06
% err = 3.4045047317168564e-06
% err = -1.2483181749679069e-05
Batch 10
% err = 1.2719971556108405e-05
% err = 5.087987932106703e-06
% err = -1.5263961725310696e-05
% err = 7.376426918719855e-06
% err = -1.7022522113164093e-06
% err = 7.376426918719855e-06
Batch 50
% err = 6.359984915133379e-06
% err = 1.6959959773689011e-06
% err = 5.087987932106703e-06
% err = 1.0780931650436712e-05
% err = 1.0780931650436712e-05
% err = 1.0780931650436712e-05
Batch Large...
  Batch 10
% err = 0.0
  Batch 20
% err = -5.486466916110714e-06
  Batch 25
% err = 1.218917210366394e-06
