In [10]:
:module +Prelude

In [1]:
import Torch
-- A 2x2 tensor filled with ones.
t :: Tensor
t = ones' [2, 2]

-- Print it to inspect the dtype, shape and contents.
print t

Tensor Float [2,2] [[ 1.0000   ,  1.0000   ],
                    [ 1.0000   ,  1.0000   ]]

In [6]:
import Torch
-- Right-hand operand: a 3x3 tensor of random values, drawn in IO.
b <- randIO' [3, 3]
-- Left-hand operand: a 3x3 tensor of ones.
let a = ones' [3, 3]
-- Matrix product of the two operands (this is the call that reaches the
-- underlying C++ matrix multiplication).
let c = a `matmul` b
print c

Tensor Float [3,3] [[ 1.5408   ,  0.7580   ,  1.7522   ],
                    [ 1.5408   ,  0.7580   ,  1.7522   ],
                    [ 1.5408   ,  0.7580   ,  1.7522   ]]

In [16]:
import Torch

-- 1. Initialize the full model structure.
-- A small vocabulary size (100) is enough when only shapes are being checked.
let vocabSize = 100
model <- initModel vocabSize

-- 2. Extract just the attention layer from the first block.
-- 'layers model' is the list of transformer blocks. It is matched explicitly
-- so that an empty list is reported with a clear message instead of the
-- generic exception thrown by the partial function 'head'.
let firstBlock = case layers model of
      (block : _) -> block
      [] -> error "model has no transformer blocks: cannot extract an attention layer"
let mhaLayer = attention firstBlock

-- 3. Create dummy input [Batch=2, Seq=10, Dim=64]
-- (Dim=64 is chosen to match the model's embedDim.)
let dummyInput = ones' [2, 10, 64]

-- 4. Run ONLY the multi-head attention forward pass, so the output shape of
-- that single component can be inspected in isolation.
let output = forwardMHA mhaLayer dummyInput

print (shape output)
-- Should be [2, 10, 64]


[2,10,64]

In [14]:
{-# LANGUAGE OverloadedStrings #-}
import IHaskell.Display
import qualified Data.Text as T
import qualified Data.Text.Lazy as LT
import Text.Printf (printf)
import qualified Torch as Th

-- | Render a 2D attention matrix as an HTML table in the notebook.
--
-- Every cell shows its value with two decimals on a blue background whose
-- opacity is the value itself, so larger weights look darker. The text colour
-- switches to white above 0.5 so it stays readable on the darker background.
--
-- The tensor must have exactly two dimensions (e.g. @[Seq, Seq]@). Any other
-- rank raises an 'IOError' that names the offending shape, instead of passing
-- the tensor to 'Th.asValue' with a list depth that does not match its rank.
-- Values are read as 'Float'.
visualizeAttention :: Th.Tensor -> IO ()
visualizeAttention t =
  case Th.shape t of
    [_, _] -> printDisplay $ Display [html tableHtml]
    other ->
      ioError . userError $
        "visualizeAttention: expected a 2D tensor, got shape " ++ show other
  where
    -- Only forced in the 2D branch above, so the conversion never runs on a
    -- tensor of the wrong rank.
    rows :: [[Float]]
    rows = Th.asValue t

    tableHtml :: String
    tableHtml =
      "<table style='border-collapse: collapse; font-family: sans-serif;'>"
        ++ concatMap rowHtml rows
        ++ "</table>"

    rowHtml :: [Float] -> String
    rowHtml r = "<tr>" ++ concatMap cell r ++ "</tr>"

    -- HTML for a single table cell.
    cell :: Float -> String
    cell val =
      printf
        "<td style='%s; color: %s; width: 40px; height: 40px; border: 1px solid #ddd; text-align: center; font-size: 12px;'>%.2f</td>"
        bgStyle
        textColor
        val
      where
        -- Background: blue, with opacity equal to the attention weight.
        bgStyle :: String
        bgStyle = printf "background-color: rgba(0, 0, 255, %.2f)" val

        -- White text on intense backgrounds (> 0.5), black otherwise, so low
        -- values never end up as white text on a near-white background.
        textColor :: String
        textColor = if val > 0.5 then "white" else "black"

In [20]:
{-# LANGUAGE RecordWildCards #-}

-- Copy this into a cell to define a debug-capable attention function
import qualified Torch.NN as NN
import qualified Torch.Functional as F
import qualified Torch as Th

-- | Multi-head attention forward pass that also exposes the attention weights.
--
-- Returns @(output, attentionWeights)@:
--
-- * @output@ is the result of 'mhaLinearOut' applied to the re-assembled
--   per-head context, reshaped to @[batch, seq, mhaEmbedDim]@.
-- * @attentionWeights@ is the softmax over the last axis of the scaled
--   query/key scores, shaped @[batch, mhaHeads, seq, seq]@.
--
-- No attention mask is applied here, so the weights show unmasked attention.
--
-- @batch@ and @seq@ are taken from the first two dimensions of the input.
-- An input with fewer than two dimensions raises an error that names the
-- offending shape (previously this surfaced as a bare 'head' / '!!' failure).
forwardMHADebug :: MultiHeadAttention -> Th.Tensor -> (Th.Tensor, Th.Tensor)
forwardMHADebug MultiHeadAttention {..} x =
  case Th.shape x of
    (batch : seqLength : _) -> attend batch seqLength
    other ->
      error $
        "forwardMHADebug: expected input shaped [batch, seq, ...], got shape "
          ++ show other
  where
    attend :: Int -> Int -> (Th.Tensor, Th.Tensor)
    attend batch seqLength = (finalOut, attnWeights)
      where
        -- 1. Linear projections
        q = NN.forward mhaLinearQ x
        k = NN.forward mhaLinearK x
        v = NN.forward mhaLinearV x

        headDim = mhaEmbedDim `Prelude.div` mhaHeads

        -- 2. Split the embedding axis into heads and move the head axis
        --    ahead of the sequence axis: [batch, heads, seq, headDim]
        viewShape = [batch, seqLength, mhaHeads, headDim]
        splitHeads = F.transpose (F.Dim 1) (F.Dim 2) . Th.reshape viewShape
        q' = splitHeads q
        k' = splitHeads k
        v' = splitHeads v

        -- 3. Scaled dot-product scores: q' * k'^T / sqrt(headDim)
        kT = F.transpose (F.Dim 2) (F.Dim 3) k'
        scoresRaw = F.matmul q' kT
        dk = Th.asTensor (fromIntegral headDim :: Float)
        scoresScaled = scoresRaw / F.sqrt dk

        -- 4. Mask: intentionally skipped in this debug variant. Add the
        --    masking logic here to see the causal triangle in the heatmap.

        -- 5. Softmax over the last axis -> THIS IS THE HEATMAP
        attnWeights = F.softmax (F.Dim 3) scoresScaled

        -- 6. Weighted sum of the values, then merge the heads back into a
        --    single embedding axis before the output projection.
        context = F.matmul attnWeights v'
        contextT = F.transpose (F.Dim 1) (F.Dim 2) context
        contextReshaped = Th.reshape [batch, seqLength, mhaEmbedDim] contextT

        finalOut = NN.forward mhaLinearOut contextReshaped

In [21]:
-- Run the debug forward pass and carve one attention map out of the returned
-- weights, all in a single binding group.
let (output, weights) = forwardMHADebug mhaLayer dummyInput
    -- Batch entry 0 (index 0 along dimension 0).
    batch0 = Th.select 0 0 weights
    -- Head 0 of that batch entry (again index 0 along dimension 0).
    head0 = Th.select 0 0 batch0

-- Full weights: expected shape is [2, 4, 10, 10]
-- (Batch=2, Heads=4, Seq=10, Seq=10).
print (Th.shape weights)

-- Single map: expected shape is [10, 10] once batch and head are selected.
print (Th.shape head0)

-- Draw the selected map as an HTML heatmap.
visualizeAttention head0

[2,4,10,10]

[10,10]

0,1,2,3,4,5,6,7,8,9
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1
