In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

In [1]:
class BinaryModel(nn.Module):
    """Decode IEEE-754 binary64 (double) bit patterns into their numeric values.

    The last input axis holds 64 bits (0/1), most significant first:
    1 sign bit, 11 exponent bits, 52 fraction (mantissa) bits. The three
    fields are read out by frozen, bias-free linear layers whose weights are
    the positional values of the bits.

    The layers are float64: every partial sum of the exponent field (integers
    below 2**11) and of the fraction field (multiples of 2**-52 below 1) is
    exactly representable there. In float32, fraction bits past the 24th are
    lost and any exponent above 127 overflows to inf.
    """

    _EXPONENT_BIAS = 1023       # binary64 exponent bias
    _EXPONENT_ALL_ONES = 2047   # exponent field value reserved for inf / NaN

    def __init__(self):
        super().__init__()

        # float64 so the bit-weighted sums in forward() are exact (see class docstring).
        self.sign_weight = nn.Linear(in_features=1, out_features=1, bias=False, dtype=torch.float64)
        self.exponent_weight = nn.Linear(in_features=11, out_features=1, bias=False, dtype=torch.float64)
        self.mantissa_weight = nn.Linear(in_features=52, out_features=1, bias=False, dtype=torch.float64)

        with torch.no_grad():
            # Sign bit s maps to -s; forward() turns that into the factor 1 - 2s = (-1)**s.
            self.sign_weight.weight.copy_(torch.tensor([[-1.0]], dtype=torch.float64))

            # Exponent field, MSB first: 2**10, ..., 2**0 -> biased exponent in [0, 2047].
            exponent_powers = torch.tensor([2.0 ** i for i in range(10, -1, -1)], dtype=torch.float64)
            self.exponent_weight.weight.copy_(exponent_powers.unsqueeze(0))

            # Fraction field, MSB first: 2**-1, ..., 2**-52 -> fraction in [0, 1).
            mantissa_powers = torch.tensor([2.0 ** (-i) for i in range(1, 53)], dtype=torch.float64)
            self.mantissa_weight.weight.copy_(mantissa_powers.unsqueeze(0))

        # Freeze weights: the decoder is a fixed function, not something to train.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, input):
        """Return the double encoded by the 64 bits on the last axis of ``input``.

        Args:
            input: Tensor of 0/1 values whose last dimension has size 64
                (sign, 11 exponent bits, 52 fraction bits, MSB first). Any
                leading batch dimensions are kept. Floating, integer and bool
                dtypes are accepted.

        Returns:
            Tensor of shape ``input.shape[:-1] + (1,)``. The value is computed
            exactly in the layers' dtype (float64 by default) and cast once at the
            end: to ``input.dtype`` for floating input, otherwise left in the
            layers' dtype. Zero, signed zero, subnormals, infinities and NaN are
            decoded per IEEE-754.
        """
        compute_dtype = self.mantissa_weight.weight.dtype
        out_dtype = input.dtype if input.is_floating_point() else compute_dtype
        bits = input.to(compute_dtype)

        # `...` keeps any leading batch dimensions; the last axis holds the 64 bits.
        sign = bits[..., :1]        # bit 0: sign
        exponent = bits[..., 1:12]  # bits 1-11: biased exponent
        mantissa = bits[..., 12:]   # bits 12-63: fraction

        # sign_weight maps s -> -s, so this is 1 - 2s = (-1)**s. Multiplying by it
        # (rather than branching) still produces -0.0 and -inf for negative inputs.
        sign_factor = 1.0 + 2.0 * self.sign_weight(sign)

        biased_exponent = self.exponent_weight(exponent)  # integer-valued, in [0, 2047]
        fraction = self.mantissa_weight(mantissa)         # in [0, 1)

        # Exponent field 0 encodes zero and subnormals: there is no implicit leading 1,
        # and the scale is 2**(1 - bias), the same as for exponent field 1 (hence clamp).
        is_subnormal = biased_exponent == 0
        significand = torch.where(is_subnormal, fraction, fraction + 1.0)
        scale = torch.pow(2.0, biased_exponent.clamp(min=1) - self._EXPONENT_BIAS)
        magnitude = scale * significand

        # Exponent field all ones encodes infinity (zero fraction) or NaN (non-zero fraction).
        special = torch.where(
            fraction == 0,
            torch.full_like(fraction, float('inf')),
            torch.full_like(fraction, float('nan')),
        )
        magnitude = torch.where(biased_exponent == self._EXPONENT_ALL_ONES, special, magnitude)

        return (sign_factor * magnitude).to(out_dtype)


def main(value=10.5):
    """Demo: encode ``value`` as IEEE-754 double bits, decode them with BinaryModel, print the result.

    The 64-bit pattern is derived from ``value`` with ``struct`` instead of being
    typed out by hand, so the demo's input and its expected output cannot drift
    apart. The default, 10.5, encodes as 0x4025000000000000.

    Args:
        value: Python float to round-trip through the model.
    """
    import struct

    model = BinaryModel()

    # '>d' packs the float as a big-endian double; '>Q' reads the same 8 bytes back
    # as an unsigned integer, so the sign bit comes first, then exponent, then fraction.
    (pattern,) = struct.unpack('>Q', struct.pack('>d', value))
    bits = [(pattern >> shift) & 1 for shift in range(63, -1, -1)]

    # One row of 64 bits, fed to the model as float32.
    input_bits = torch.tensor([bits], dtype=torch.float32)

    output = model(input_bits)
    print(output.item())


# In a notebook kernel __name__ is '__main__', so executing this cell also runs the demo.
if __name__ == '__main__':
    main()

NameError: name 'nn' is not defined