In [1]:
import torch
import torch.nn as nn
import copy
from icecream import ic

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Callable, Iterable, Iterator, Optional, Tuple, List, Any, Sequence

@dataclass
class Test:
    a: Sequence[int]
    b: Sequence[float]
    
    def pop(self, size) -> Test:
        
        ret = Test(self.a[:size], self.b[:size])
        
        for field in fields(self):
            setattr(self, field.name, getattr(self, field.name)[size:])
    
        return ret
    
    def __len__(self):
        return len(self.a)
    
    def __add__(self, other):
        ret = Test(None, None)
        for field in fields(self):
            setattr(ret, field.name, getattr(self, field.name) + getattr(other, field.name))
        return ret

    def __iadd__(self, other):
        for field in fields(self):
            setattr(self, field.name, getattr(self, field.name).__iadd__(getattr(other, field.name)))
        return self

    def __iter__(self) -> Iterator:
        return iter([
            getattr(self, field.name)
            for field in fields(self)
        ])

@dataclass
class ChildTest(Test):
    child: str
    
ic(fields(Test))
ic(fields(ChildTest))

class TestIterator:

    def __init__(self, src, size = 1):
        self.src = src
        self.size = size
    
    def __iter__(self):
        return self

    def __next__(self):
        ret = self.src.pop(self.size)
        if len(ret) > 0:
            return ret
        else:
            raise StopIteration

t1 = Test(
    a=[i for i in range(10)],
    b=[i / 10 for i in range(10)],
)
t2 = Test(
    a=[i for i in range(10, 20)],
    b=[i / 10 for i in range(10, 20)],
)

ic| fields(Test): (Field(name='a',type='Sequence[int]',default=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,default_factory=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
                   Field(name='b',type='Sequence[float]',default=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,default_factory=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD))
ic| fields(ChildTest): (Field(name='a',type='Sequence[int]',default=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,default_factory=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
                        Field(name='b',type='Sequence[float]',default=<dataclasses._MISSING_TYPE object at 0x7f5510ed27d0>,def

In [21]:
t1 += t2
t1


Test(a=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], b=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9])

In [23]:
t1

Test(a=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], b=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9])

In [29]:
list_test = [ Test(i, i + 0.5) for i in range(5) ]
a = list(map(list, zip(*list_test)))
t = Test(*list(map(list, zip(*list_test))))
t

Test(a=[0, 1, 2, 3, 4], b=[0.5, 1.5, 2.5, 3.5, 4.5])

In [39]:
a = [0,1,2,3,4,5]
b = a
b = b[2:]
print(a)
print(b)
print(f'mean: {torch.Tensor(a).mean().item()}')
b[slice(4,5)]
d1 = {i: str(i) for i in range(5)}
d2 = {i: str(i) for i in range(5, 10)}
print({**d1, **d2})

[0, 1, 2, 3, 4, 5]
[2, 3, 4, 5]
mean: 2.5
{0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9'}


In [78]:
m1 = torch.load('outputs/2022-06-21/13-20-32/0.pt')
m3 = torch.load('outputs/2022-06-21/14-18-27/11.pt')
m2 = torch.load('outputs/2022-06-21/14-18-27/iter_2.pt')['model_state_dict']


In [43]:
m1.keys()

odict_keys(['graph_embedding.embedding.weight', 'graph_embedding.conv_0.linear2.weight', 'graph_embedding.conv_0.linear2.bias', 'graph_embedding.conv_0.linear1.weight', 'graph_embedding.convs.0.linear2.weight', 'graph_embedding.convs.0.linear2.bias', 'graph_embedding.convs.0.linear1.weight', 'graph_embedding.convs.1.linear2.weight', 'graph_embedding.convs.1.linear2.bias', 'graph_embedding.convs.1.linear1.weight', 'graph_embedding.convs.2.linear2.weight', 'graph_embedding.convs.2.linear2.bias', 'graph_embedding.convs.2.linear1.weight', 'graph_embedding.convs.3.linear2.weight', 'graph_embedding.convs.3.linear2.bias', 'graph_embedding.convs.3.linear1.weight', 'graph_embedding.convs.4.linear2.weight', 'graph_embedding.convs.4.linear2.bias', 'graph_embedding.convs.4.linear1.weight', 'actor.0.weight', 'actor.0.bias', 'actor.2.weight', 'actor.2.bias', 'critic.0.weight', 'critic.0.bias', 'critic.2.weight', 'critic.2.bias'])

In [79]:
key = 'graph_embedding.embedding.weight'
m3[key] == m2[key]

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, Tru

In [82]:
x = torch.rand(10, 5)
torch.stack([x, x]).shape

torch.Size([2, 10, 5])

In [2]:
class A:
    def __init__(self):
        pass
    
    @staticmethod
    def name() -> str:
        return 'A'

a = A()
A.name()

'A'