<a href="https://colab.research.google.com/github/kazu-gor/othello/blob/develop/models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def dict_to_cpu(dictionary):
    """Return a copy of *dictionary* with every tensor moved to the CPU.

    Nested ``dict`` values (including ``dict`` subclasses) are converted
    recursively. Any value that is neither a ``torch.Tensor`` nor a ``dict``
    is carried over unchanged, by reference. The result is always a plain
    ``dict`` that preserves the input's key order.
    """
    def _convert(item):
        # Tensors take priority over the dict check, matching the original order.
        if isinstance(item, torch.Tensor):
            return item.cpu()
        if isinstance(item, dict):
            return dict_to_cpu(item)
        return item

    return {key: _convert(item) for key, item in dictionary.items()}

In [None]:
class AbstractNetwork(ABC, nn.Module):
    """Abstract ``nn.Module`` base for networks with two inference entry points.

    Subclasses must implement ``initial_inference`` and ``recurrent_inference``.
    Until they do, the class stays abstract and cannot be instantiated.
    Weight export and import are provided through ``get_weights`` and
    ``set_weights``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def initial_inference(self, observation):
        """Run inference starting from *observation* (subclass responsibility)."""

    @abstractmethod
    def recurrent_inference(self, encoded_state, action):
        """Run inference from *encoded_state* and *action* (subclass responsibility)."""

    def get_weights(self):
        """Return this module's ``state_dict`` as a plain dict with CPU tensors."""
        state = self.state_dict()
        return dict_to_cpu(state)

    def set_weights(self, weights):
        """Load *weights* (a state-dict mapping) into this module; returns None."""
        self.load_state_dict(weights)

In [None]:
class MuZeroFullyConnectedNetwork(AbstractNetwork):
    """Fully connected ``AbstractNetwork`` subclass -- unfinished stub.

    NOTE(review): as written this class is not usable:
      * It does not override ``initial_inference`` or ``recurrent_inference``.
        It therefore remains abstract, and instantiating it raises
        ``TypeError``.
      * ``nn.DataParallel`` requires a ``module`` argument, so the call in
        ``__init__`` would raise ``TypeError`` even once the abstract methods
        are implemented.
      * The attribute name ``representaton_network`` looks like a typo for
        ``representation_network``. Confirm before other code depends on it.
    """

    def __init__(self):
        super().__init__()
        
        # TODO(review): pass the representation sub-network (an nn.Module) to
        # nn.DataParallel; called with no argument this raises TypeError.
        self.representaton_network = nn.DataParallel(
            
        )

In [None]:
def conv3x3(in_channels, out_channels, stride=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 and no bias term.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 1). With stride 1, the padding of 1
            preserves the spatial size of the input.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    # The parameter was previously misspelled ``out_chennels`` while the body
    # read ``out_channels``, so every call raised NameError.
    # bias=False is kept from the original; this is typical when the conv feeds
    # a BatchNorm layer, which provides its own learnable shift.
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )

In [None]:
class ResidualBlock(nn.Module):
    """Two-convolution residual block: conv-BN-ReLU-conv-BN, add skip, ReLU.

    Args:
        num_channels: Channel count of both the input and the output.
        stride: Stride of the first 3x3 convolution (default 1).
            * With stride 1 the skip connection is the identity.
            * With any other stride the main path's spatial size shrinks.
              The skip is then projected by a strided 1x1 convolution plus
              BatchNorm so the two tensors can be added; previously this case
              failed with a shape mismatch at the addition.
    """

    def __init__(self, num_channels, stride=1):
        super().__init__()
        self.conv1 = conv3x3(num_channels, num_channels, stride)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = conv3x3(num_channels, num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        # Only create the projection when it is needed, so the stride-1
        # parameters (and state_dict keys) are unchanged.
        if stride != 1:
            # A 1x1 conv with padding 0 and the same stride yields the same
            # output size as the 3x3/padding-1 conv on the main path.
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    num_channels, num_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(num_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Apply the block to *x*, a 4-D (N, num_channels, H, W) tensor."""
        identity = x if self.downsample is None else self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = F.relu(out)
        return out