In [1]:
import torch
from torch.nn import functional as F
import einops
from huggingface_hub import hf_hub_download
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader  
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

class LCA(torch.nn.Module):
    def __init__(self, input_dim, dict_size, lambd=0.1, lr=None, max_iter=100, 
                 fac=0.5, tol=1e-6, device='cuda', verbose=False):
        super().__init__()
        self.input_dim = input_dim
        self.dict_size = dict_size
        self.D = torch.nn.Parameter(torch.randn(input_dim, dict_size))
        self.normalize_dictionary()
        self.lambd = lambd
        self.lr = lr if lr is not None else 1.0 / dict_size
        self.max_iter = max_iter
        self.facs = [fac, 1/fac]
        self.tol = tol
        self.device = device
        self.verbose = verbose

    def normalize_dictionary(self):
        with torch.no_grad():
            self.D.data = F.normalize(self.D.data, dim=0)

    def inference(self, x):
        batch_size = x.shape[0]
        u = torch.zeros(batch_size, self.dict_size, device=self.device)
        a = torch.relu(u)
        
        if isinstance(self.lr, float):
            lr = torch.full((batch_size,), self.lr, device=self.device)
        else:
            lr = self.lr.clone()
        
        best_loss = torch.full((batch_size,), float('inf'), device=self.device)
        
        for iter_idx in range(self.max_iter):
            rec = torch.matmul(a, self.D.T)
            
            recon_error = torch.mean((x - rec) ** 2, dim=1)
            l1_penalty = self.lambd * torch.mean(torch.abs(a), dim=1)
            loss = recon_error + l1_penalty
            
            if self.verbose and iter_idx % 10 == 0:
                avg_loss = loss.mean().item()
                avg_recon = recon_error.mean().item()
                avg_l1 = l1_penalty.mean().item()
                sparsity = (a > 0).float().mean().item() * 100
                print(f"LCA Iteration {iter_idx}: "
                      f"Loss = {avg_loss:.6f}, "
                      f"Recon = {avg_recon:.6f}, "
                      f"L1 = {avg_l1:.6f}, "
                      f"Sparsity = {sparsity:.1f}%")
            
            if torch.max(best_loss - loss) < self.tol:
                if self.verbose:
                    print(f"Converged at iteration {iter_idx}")
                break
            
            best_loss = loss
            du = torch.matmul((rec - x), self.D) + self.lambd
            
            losses = []
            u_candidates = []
            a_candidates = []
            
            for fac in self.facs:
                lr_expanded = lr.view(-1, 1)
                u_new = u - du * (lr_expanded * fac)
                a_new = torch.relu(u_new)
                rec_new = torch.matmul(a_new, self.D.T)
                
                recon_error_new = torch.mean((x - rec_new) ** 2, dim=1)
                l1_penalty_new = self.lambd * torch.mean(torch.abs(a_new), dim=1)
                loss_new = recon_error_new + l1_penalty_new
                
                losses.append(loss_new)
                u_candidates.append(u_new)
                a_candidates.append(a_new)
            
            losses = torch.stack(losses, dim=0)
            best_idx = torch.argmin(losses, dim=0)
            
            u = torch.stack([
                u_candidates[best_idx[i]][i] for i in range(batch_size)
            ])
            
            a = torch.relu(u)
            lr = lr * torch.tensor([self.facs[idx.item()] for idx in best_idx], 
                                 device=self.device)
        
        return a, rec, recon_error.mean(), l1_penalty.mean()

    def forward(self, x):
        self.normalize_dictionary()
        a = self.inference(x)
        return a

def normalized_mse(recon, xs):
    mse = F.mse_loss(recon, xs)
    mean_xs = xs.mean(dim=0, keepdim=True).expand_as(xs)
    mse_mean = F.mse_loss(mean_xs, xs)
    epsilon = 1e-8
    return mse / (mse_mean + epsilon)

In [2]:
# Set parameters
repo_name = "charlieoneill/sparse-coding-lca"
model_filename = "lca_model.pth"
input_dim = 768
dict_size = 16896
device = 'cpu'
layer = 9
lambd = 0.1  # Sparsity parameter

print("Loading model...")
# Download and load the model
model_path = hf_hub_download(repo_id=repo_name, filename=model_filename)

# Initialize model
lca = LCA(
    input_dim=input_dim,
    dict_size=dict_size,
    lambd=lambd,
    max_iter=100,
    device=device,
    verbose=True
)

# Load saved weights
lca.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
lca.eval()
print("Model loaded successfully")

# Load transformer and activation store
hook_point = "blocks.8.hook_resid_pre"
saes, _ = get_gpt2_res_jb_saes(hook_point)
sparse_autoencoder = saes[hook_point]
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device
sparse_autoencoder.cfg.hook_point = f"blocks.{layer}.attn.hook_z"
sparse_autoencoder.cfg.store_batch_size = 64

loader = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg)
transformer_model, _, activation_store = loader.load_sae_training_group_session()

print("Running inference...")
# Get activations and run inference
batch_tokens = activation_store.get_batch_tokens().to(device)

with torch.no_grad():
    _, cache = transformer_model.run_with_cache(batch_tokens)
    X = cache["resid_pre", layer]
    X = einops.rearrange(X, "batch pos d_model -> (batch pos) d_model")
    X = X[:64]  # Ensure batch size of 64
    del cache
    
    # Run LCA inference
    S_, X_, recon_loss, l1_loss = lca.inference(X)
    
    # Calculate sparsity
    sparsity = (S_ > 0).float().mean().item() * 100
    
    print("\nResults:")
    print(f"Reconstruction loss: {recon_loss.item():.6f}")
    print(f"L1 loss: {l1_loss.item():.6f}")
    print(f"Sparsity: {sparsity:.1f}%")
    print(f"Dictionary max/min/mean: {lca.D.max():.3f}/{lca.D.min():.3f}/{lca.D.mean():.3f}")
    
    # Optional: analyze feature usage
    feature_usage = (S_ > 0).float().mean(dim=0)
    print(f"\nFeature usage statistics:")
    print(f"Mean feature usage: {feature_usage.mean().item():.3f}")
    print(f"Std feature usage: {feature_usage.std().item():.3f}")
    print(f"Dead features (never activated): {(feature_usage == 0).sum().item()}")

Loading model...


lca_model.pth:   0%|          | 0.00/51.9M [00:00<?, ?B/s]

Model loaded successfully


100%|██████████| 1/1 [00:00<00:00,  1.01it/s]


Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Running inference...
LCA Iteration 0: Loss = 220.008682, Recon = 220.008682, L1 = 0.000000, Sparsity = 0.0%
LCA Iteration 10: Loss = 116.371948, Recon = 116.361961, L1 = 0.009995, Sparsity = 30.9%
LCA Iteration 20: Loss = 29.948868, Recon = 29.928413, L1 = 0.020456, Sparsity = 22.3%
LCA Iteration 30: Loss = 9.150305, Recon = 9.126734, L1 = 0.023570, Sparsity = 21.0%
LCA Iteration 40: Loss = 2.856406, Recon = 2.831342, L1 = 0.025064, Sparsity = 20.0%
LCA Iteration 50: Loss = 0.950766, Recon = 0.925015, L1 = 0.025751, Sparsity = 19.2%
LCA Iteration 60: Loss = 0.306013, Recon = 0.279994, L1 = 0.026019, Sparsity = 18.3%
LCA Iteration 70: Loss = 0.134977, Recon = 0.108981, L1 = 0.025996, Sparsity = 17.5%
LCA Iteration 80: Loss = 0.070541, Recon = 0.044646, L1 = 0.025895, Sparsity = 16.8%
LCA Iteration 90: Loss = 0.051264, Recon = 0.025491, L1 = 0.025774, Sparsity = 16.3%

Results:
Reconstruction loss: 0.018980
L1 loss: 0.025655
Sparsity: 15.8%
Dictionary max/min/mean: 0.187/-0.192/-0.000

F

## LCA

In [9]:
import torch
from huggingface_hub import hf_hub_download
import einops
import numpy as np
import yaml
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from datasets import load_dataset
import torch.nn as nn
from torch.nn import functional as F

# ----------------------------- Configuration -----------------------------

# Parameters
REPO_NAME = "charlieoneill/sparse-coding-lca"
MODEL_FILENAME = "lca_model.pth"
INPUT_DIM = 768
DICT_SIZE = 16896
SCORES_PATH = "lca_scores.npy"
FEATURE_INDICES = list(range(100))  # Example: analyze first 100 dictionary elements
TOP_K = 10
LAYER = 9
CONFIG_PATH = "config.yaml"

def get_device():
    """Get the appropriate device (CUDA if available, else CPU)."""
    if torch.cuda.is_available():
        try:
            # Test CUDA availability
            torch.zeros(1).cuda()
            return 'cuda'
        except:
            return 'cpu'
    return 'cpu'

class LCA(nn.Module):
    def __init__(self, input_dim, dict_size, lambd=0.1, lr=None, max_iter=100, 
                 fac=0.5, tol=1e-6, verbose=False):
        super().__init__()
        self.input_dim = input_dim
        self.dict_size = dict_size
        self.D = nn.Parameter(torch.randn(input_dim, dict_size))
        self.normalize_dictionary()
        self.lambd = lambd
        self.lr = lr if lr is not None else 1.0 / dict_size
        self.max_iter = max_iter
        self.facs = [fac, 1/fac]
        self.tol = tol
        self.verbose = verbose

    def normalize_dictionary(self):
        with torch.no_grad():
            self.D.data = F.normalize(self.D.data, dim=0)

    def inference(self, x):
        device = x.device
        batch_size = x.shape[0]
        u = torch.zeros(batch_size, self.dict_size, device=device)
        a = torch.relu(u)
        
        if isinstance(self.lr, float):
            lr = torch.full((batch_size,), self.lr, device=device)
        else:
            lr = self.lr.clone().to(device)
        
        best_loss = torch.full((batch_size,), float('inf'), device=device)
        
        # Track optimization metrics
        loss_history = []
        recon_history = []
        l1_history = []
        sparsity_history = []
        lr_history = []
        
        print(f"\nStarting LCA inference on tensor of shape {x.shape}")
        print(f"Initial lr: {lr[0].item():.6f}")
        
        for iter_idx in range(self.max_iter):
            rec = torch.matmul(a, self.D.T)
            recon_error = torch.mean((x - rec) ** 2, dim=1)
            l1_penalty = self.lambd * torch.mean(torch.abs(a), dim=1)
            loss = recon_error + l1_penalty
            
            # Calculate current metrics
            avg_loss = loss.mean().item()
            avg_recon = recon_error.mean().item()
            avg_l1 = l1_penalty.mean().item()
            sparsity = (a > 0).float().mean().item() * 100
            avg_lr = lr.mean().item()
            
            # Store history
            loss_history.append(avg_loss)
            recon_history.append(avg_recon)
            l1_history.append(avg_l1)
            sparsity_history.append(sparsity)
            lr_history.append(avg_lr)
            
            if iter_idx % 10 == 0 or iter_idx == self.max_iter - 1:
                print(f"\nIteration {iter_idx}:")
                print(f"  Loss: {avg_loss:.6f}")
                print(f"  Reconstruction Error: {avg_recon:.6f}")
                print(f"  L1 Penalty: {avg_l1:.6f}")
                print(f"  Sparsity: {sparsity:.1f}%")
                print(f"  Learning Rate: {avg_lr:.6f}")
                print(f"  Active Features: {int((a > 0).sum().item() / batch_size)}/{self.dict_size}")
                
                # Print distribution of activations if verbose
                if self.verbose:
                    with torch.no_grad():
                        active_vals = a[a > 0]
                        if len(active_vals) > 0:
                            print(f"  Activation stats:")
                            print(f"    Mean: {active_vals.mean().item():.6f}")
                            print(f"    Std: {active_vals.std().item():.6f}")
                            print(f"    Max: {active_vals.max().item():.6f}")
                            print(f"    Min: {active_vals.min().item():.6f}")
            
            if torch.max(best_loss - loss) < self.tol:
                print(f"\nConverged at iteration {iter_idx}")
                print(f"Final sparsity: {sparsity:.1f}%")
                print(f"Final reconstruction error: {avg_recon:.6f}")
                break
            
            best_loss = loss
            du = torch.matmul((rec - x), self.D) + self.lambd
            
            losses = []
            u_candidates = []
            a_candidates = []
            
            for fac in self.facs:
                lr_expanded = lr.view(-1, 1)
                u_new = u - du * (lr_expanded * fac)
                a_new = torch.relu(u_new)
                rec_new = torch.matmul(a_new, self.D.T)
                
                recon_error_new = torch.mean((x - rec_new) ** 2, dim=1)
                l1_penalty_new = self.lambd * torch.mean(torch.abs(a_new), dim=1)
                loss_new = recon_error_new + l1_penalty_new
                
                losses.append(loss_new)
                u_candidates.append(u_new)
                a_candidates.append(a_new)
            
            losses = torch.stack(losses, dim=0)
            best_idx = torch.argmin(losses, dim=0)
            
            u = torch.stack([
                u_candidates[best_idx[i]][i] for i in range(batch_size)
            ])
            
            a = torch.relu(u)
            lr = lr * torch.tensor([self.facs[idx.item()] for idx in best_idx], 
                                device=device)
            
            # Print learning rate adaptation info if verbose
            if self.verbose and (iter_idx % 10 == 0):
                fac_counts = torch.bincount(best_idx, minlength=len(self.facs))
                print("\nLearning rate adaptation:")
                for i, count in enumerate(fac_counts):
                    print(f"  Factor {self.facs[i]}: chosen {count.item()} times")
        
        # Print final statistics
        print("\nOptimization complete:")
        print(f"Initial loss: {loss_history[0]:.6f}")
        print(f"Final loss: {loss_history[-1]:.6f}")
        print(f"Loss reduction: {(1 - loss_history[-1]/loss_history[0])*100:.1f}%")
        print(f"Initial sparsity: {sparsity_history[0]:.1f}%")
        print(f"Final sparsity: {sparsity_history[-1]:.1f}%")
        print(f"Initial lr: {lr_history[0]:.6f}")
        print(f"Final lr: {lr_history[-1]:.6f}")
        
        return a

    def forward(self, x):
        self.normalize_dictionary()
        self.D = self.D.to(x.device)
        a = self.inference(x)
        return a

# ----------------------------- Helper Functions -----------------------------

def load_transformer_model(model_name: str = 'gpt2-small') -> HookedTransformer:
    """Load the transformer model."""
    model = HookedTransformer.from_pretrained(model_name)
    return model.cpu()

def load_tokenized_data(max_length: int = 128, batch_size: int = 64, take_size: int = 102400) -> torch.Tensor:
    """Load and tokenize the OpenWebText dataset."""
    def tokenize_and_concatenate(dataset, tokenizer, streaming=False, max_length=1024, column_name="text", add_bos_token=True):
        for key in dataset.features:
            if key != column_name:
                dataset = dataset.remove_columns(key)
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        seq_len = max_length - 1 if add_bos_token else max_length

        def tokenize_function(examples):
            text = examples[column_name]
            full_text = tokenizer.eos_token.join(text)
            num_chunks = 20
            chunk_length = (len(full_text) - 1) // num_chunks + 1
            chunks = [full_text[i * chunk_length: (i + 1) * chunk_length] for i in range(num_chunks)]
            tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
            tokens = tokens[tokens != tokenizer.pad_token_id]
            num_tokens = len(tokens)
            num_batches = num_tokens // seq_len
            tokens = tokens[: seq_len * num_batches]
            tokens = einops.rearrange(tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len)
            if add_bos_token:
                prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
                tokens = np.concatenate([prefix, tokens], axis=1)
            return {"tokens": tokens}

        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=[column_name])
        return tokenized_dataset

    transformer_model = load_transformer_model()
    dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
    dataset = dataset.shuffle(seed=42, buffer_size=10_000)
    tokenized_owt = tokenize_and_concatenate(dataset, transformer_model.tokenizer, max_length=max_length, streaming=True)
    tokenized_owt = tokenized_owt.shuffle(42)
    tokenized_owt = tokenized_owt.take(take_size)
    owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
    owt_tokens_torch = torch.tensor(owt_tokens)
    return owt_tokens_torch

def load_lca(repo_name: str, model_filename: str, input_dim: int, dict_size: int, device: str) -> torch.nn.Module:
    """Load the LCA model from HuggingFace Hub."""
    model_path = hf_hub_download(repo_id=repo_name, filename=model_filename)
    model = LCA(input_dim=input_dim, dict_size=dict_size)
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model.to(device)

def compute_scores(lca: torch.nn.Module, transformer_model: HookedTransformer, 
                  owt_tokens_torch: torch.Tensor, layer: int, device: str) -> np.ndarray:
    """Compute activation scores for each dictionary element."""
    print(f"Computing scores using device: {device}")
    scores = []
    batch_size = 64
    
    for i in tqdm(range(0, owt_tokens_torch.shape[0], batch_size), desc="Computing scores"):
        batch_tokens = owt_tokens_torch[i:i + batch_size].to(device)
        
        with torch.no_grad():
            _, cache = transformer_model.run_with_cache(
                batch_tokens, 
                stop_at_layer=layer + 1, 
                names_filter=None
            )
            X = cache["resid_pre", layer].to(device)
            X = einops.rearrange(X, "batch pos d_model -> (batch pos) d_model")
            del cache
            
            # Get LCA activations
            activations = lca(X)
            
            # Reshape activations to match the format we need
            scores_reshaped = einops.rearrange(
                activations, 
                "(b pos) n -> b n pos", 
                pos=batch_tokens.shape[1]
            ).cpu().numpy().astype(np.float16)
            
            scores.append(scores_reshaped)

    scores = np.concatenate(scores, axis=0)
    np.save(SCORES_PATH, scores)
    return scores

def get_topk_bottomk_logits(dict_element_index: int, lca: torch.nn.Module, 
                           transformer_model: HookedTransformer, k: int = TOP_K) -> tuple:
    """Get top-k and bottom-k logits for a dictionary element."""
    dict_vector = lca.D.data[:, dict_element_index]
    W_U = transformer_model.W_U
    logits = einops.einsum(W_U, dict_vector, "d_model vocab, d_model -> vocab")
    top_k_logits = logits.topk(k).indices
    bottom_k_logits = logits.topk(k, largest=False).indices
    top_k_tokens = [transformer_model.to_string(x.item()) for x in top_k_logits]
    bottom_k_tokens = [transformer_model.to_string(x.item()) for x in bottom_k_logits]
    return top_k_tokens, bottom_k_tokens

In [10]:
# Determine device
device = get_device()
print(f"Using device: {device}")

# Load configuration
config = yaml.safe_load(open(CONFIG_PATH))

# Load models
lca = load_lca(REPO_NAME, MODEL_FILENAME, INPUT_DIM, DICT_SIZE, device)
transformer_model = load_transformer_model().to(device)

# Load or compute scores
try:
    scores = np.load(SCORES_PATH)
    print("Loaded pre-computed scores.")
except FileNotFoundError:
    print(f"Scores file not found at {SCORES_PATH}. Computing scores...")
    owt_tokens_torch = load_tokenized_data()
    scores = compute_scores(lca, transformer_model, owt_tokens_torch, LAYER, device)

# Load tokenized data
owt_tokens_torch = load_tokenized_data()

# Calculate top examples for each dictionary element
for feature_idx in FEATURE_INDICES:
    print(f"\nAnalyzing dictionary element {feature_idx}")
    
    # Get top activating examples
    feature_scores = scores[:, feature_idx, :]
    flat_scores = feature_scores.flatten()
    top_k_indices = flat_scores.argsort()[-TOP_K:][::-1]
    top_k_scores = flat_scores[top_k_indices]
    
    # Get batch and sequence indices
    top_k_batch_indices = top_k_indices // feature_scores.shape[1]
    top_k_seq_indices = top_k_indices % feature_scores.shape[1]
    
    # Get corresponding tokens and scores
    top_k_tokens = [owt_tokens_torch[batch_idx].tolist() for batch_idx in top_k_batch_indices]
    top_k_tokens_str = [[transformer_model.to_string(x) for x in token_seq] 
                        for token_seq in top_k_tokens]
    
    # Get logits information
    top_logits, bottom_logits = get_topk_bottomk_logits(feature_idx, lca, transformer_model)
    
    print(f"\nTop activating tokens for dictionary element {feature_idx}:")
    for i, (tokens, score) in enumerate(zip(top_k_tokens_str, top_k_scores)):
        print(f"\nExample {i+1} (activation: {score:.4f}):")
        print("".join(tokens))
    
    print(f"\nTop boosted tokens: {', '.join(top_logits)}")
    print(f"Bottom boosted tokens: {', '.join(bottom_logits)}")

Using device: cpu
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu
Moving model to device:  cpu
Scores file not found at lca_scores.npy. Computing scores...
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu


Token indices sequence length is longer than the specified maximum sequence length for this model (73252 > 1024). Running this sequence through the model will result in indexing errors


Computing scores using device: cpu


Computing scores:   0%|          | 0/1600 [00:00<?, ?it/s]


Starting LCA inference on tensor of shape torch.Size([8192, 768])
Initial lr: 0.000059

Iteration 0:
  Loss: 119.724327
  Reconstruction Error: 119.724327
  L1 Penalty: 0.000000
  Sparsity: 0.0%
  Learning Rate: 0.000059
  Active Features: 0/16896

Iteration 10:
  Loss: 61.526257
  Reconstruction Error: 61.517220
  L1 Penalty: 0.009039
  Sparsity: 31.7%
  Learning Rate: 0.060606
  Active Features: 5361/16896

Iteration 20:
  Loss: 15.271048
  Reconstruction Error: 15.252929
  L1 Penalty: 0.018119
  Sparsity: 22.6%
  Learning Rate: 0.060445
  Active Features: 3811/16896

Iteration 30:
  Loss: 4.670774
  Reconstruction Error: 4.650534
  L1 Penalty: 0.020240
  Sparsity: 21.3%
  Learning Rate: 0.060728
  Active Features: 3593/16896

Iteration 40:
  Loss: 1.510661
  Reconstruction Error: 1.489450
  L1 Penalty: 0.021212
  Sparsity: 20.3%
  Learning Rate: 0.065844
  Active Features: 3434/16896

Iteration 50:
  Loss: 0.471220
  Reconstruction Error: 0.449587
  L1 Penalty: 0.021633
  Sparsity:

KeyboardInterrupt: 

In [43]:
import random
from typing import List, Dict, Tuple

def group_remaining_players(remaining_males: List[Dict], remaining_females: List[Dict]) -> List[List[Dict]]:
    """
    Group remaining players into teams of 2 or 3.
    """
    remaining_teams = []
    all_remaining = remaining_males + remaining_females
    total_remaining = len(all_remaining)
    
    if total_remaining == 0:
        return []
    elif total_remaining <= 4:
        # If 4 or fewer players, put them all in one team
        remaining_teams.append(all_remaining)
    else:
        # For 5 or more players, create teams of 3 and 2
        num_players = len(all_remaining)
        if num_players % 3 == 0:
            # If divisible by 3, create all teams of 3
            for i in range(0, num_players, 3):
                remaining_teams.append(all_remaining[i:i+3])
        else:
            # Create as many teams of 3 as possible, then one team of 2
            num_teams_of_three = num_players // 3
            players_in_teams_of_three = num_teams_of_three * 3
            
            # Create teams of 3
            for i in range(0, players_in_teams_of_three, 3):
                remaining_teams.append(all_remaining[i:i+3])
            
            # Create final team with remaining players
            if players_in_teams_of_three < num_players:
                remaining_teams.append(all_remaining[players_in_teams_of_three:])
    
    return remaining_teams

def create_teams(participants: List[Dict[str, str]], seed: int = 20) -> List[List[Dict[str, str]]]:
    """
    Create primary teams of 3 (2M+1F) and handle remaining players.
    """
    # Set random seed for reproducibility
    random.seed(seed)
    
    # Filter out uncertain participants and separate by gender
    certain_participants = [p for p in participants if '?' not in p['nickname']]
    males = [p for p in certain_participants if p['gender'] == 'M']
    females = [p for p in certain_participants if p['gender'] == 'F']
    
    # Shuffle both lists
    random.shuffle(males)
    random.shuffle(females)
    
    # Calculate number of complete teams possible
    num_possible_teams = min(len(females), len(males) // 2)
    
    primary_teams = []
    # Create complete teams (2 males + 1 female)
    for i in range(num_possible_teams):
        team = [
            females[i],  # One female
            males.pop(0),  # First male
            males.pop(0)   # Second male
        ]
        primary_teams.append(team)
    
    # Handle remaining players
    remaining_males = males
    remaining_females = females[num_possible_teams:]
    remaining_teams = group_remaining_players(remaining_males, remaining_females)
    
    return primary_teams, remaining_teams

# Create complete participant list
participants = [
    {"name": "Poswell", "nickname": "Anteater", "gender": "M"},
    {"name": "Sengupta", "nickname": "George", "gender": "M"},
    {"name": "Chuck", "nickname": "Albert II", "gender": "M"},
    {"name": "Loki", "nickname": "Sandy", "gender": "M"},
    {"name": "Dicko", "nickname": "Trust Fund Baby", "gender": "M"},
    {"name": "Jackman", "nickname": "Scuba Diver", "gender": "M"},
    {"name": "Jack Wu", "nickname": "My Quant", "gender": "M"},
    {"name": "Sieb", "nickname": "Long Distance King", "gender": "M"},
    {"name": "Georgie", "nickname": "Petrol Burner", "gender": "F"},
    {"name": "Bean", "nickname": "?", "gender": "F"},
    {"name": "Wanless", "nickname": "the trashman", "gender": "M"},
    {"name": "Jacqui", "nickname": "Sengupta's First", "gender": "F"},
    {"name": "HPL", "nickname": "LPH", "gender": "M"},
    {"name": "Indigo", "nickname": "Indiiiia", "gender": "F"},
    {"name": "Audrey", "nickname": "Astronaut", "gender": "F"},
    {"name": "Priscilla", "nickname": "Pastaless", "gender": "F"},
    {"name": "Tom", "nickname": "East India Trading Company", "gender": "M"},
    {"name": "Scollingwood", "nickname": "Carlton", "gender": "M"},
    {"name": "Gemma", "nickname": "Sengupta's Last", "gender": "F"}
]

# Generate teams
primary_teams, remaining_teams = create_teams(participants)

# Print all teams
print("\nPrimary Teams (2M+1F):")
print("-" * 50)
for i, team in enumerate(primary_teams, 1):
    print(f"\nTeam {i}:")
    for player in team:
        print(f"- {player['name']} ({player['nickname']})")

if remaining_teams:
    print("\nAdditional Teams:")  # Changed from "Remaining Teams" for clarity
    print("-" * 50)
    for i, team in enumerate(remaining_teams, len(primary_teams) + 1):
        print(f"\nTeam {i}:")
        for player in team:
            print(f"- {player['name']} ({player['nickname']})")

print("\nNote: Finn is also serving as Caddy")


Primary Teams (2M+1F):
--------------------------------------------------

Team 1:
- Indigo (Indiiiia)
- Jackman (Scuba Diver)
- Loki (Sandy)

Team 2:
- Jacqui (Sengupta's First)
- Poswell (Anteater)
- Sieb (Long Distance King)

Team 3:
- Gemma (Sengupta's Last)
- Wanless (the trashman)
- HPL (LPH)

Team 4:
- Priscilla (Pastaless)
- Jack Wu (My Quant)
- Sengupta (George)

Team 5:
- Georgie (Petrol Burner)
- Dicko (Trust Fund Baby)
- Chuck (Albert II)

Team 6:
- Audrey (Astronaut)
- Tom (East India Trading Company)
- Scollingwood (Carlton)

Note: Finn is also serving as Caddy


In [3]:
import random
from typing import List, Dict, Tuple

def group_remaining_players(remaining_males: List[Dict], remaining_females: List[Dict]) -> List[List[Dict]]:
    """
    Group remaining players into teams of 2 or 3.
    """
    remaining_teams = []
    all_remaining = remaining_males + remaining_females
    total_remaining = len(all_remaining)
    
    if total_remaining == 0:
        return []
    elif total_remaining <= 4:
        # If 4 or fewer players, put them all in one team
        remaining_teams.append(all_remaining)
    else:
        # For 5 or more players, create teams of 3 and 2
        num_players = len(all_remaining)
        if num_players % 3 == 0:
            # If divisible by 3, create all teams of 3
            for i in range(0, num_players, 3):
                remaining_teams.append(all_remaining[i:i+3])
        else:
            # Create as many teams of 3 as possible, then one team of 2
            num_teams_of_three = num_players // 3
            players_in_teams_of_three = num_teams_of_three * 3
            
            # Create teams of 3
            for i in range(0, players_in_teams_of_three, 3):
                remaining_teams.append(all_remaining[i:i+3])
            
            # Create final team with remaining players
            if players_in_teams_of_three < num_players:
                remaining_teams.append(all_remaining[players_in_teams_of_three:])
    
    return remaining_teams

def create_teams(participants: List[Dict[str, str]], seed: int = 41) -> List[List[Dict[str, str]]]:
    """
    Create primary teams of 3 (2M+1F) and handle remaining players.
    """
    # Set random seed for reproducibility
    random.seed(seed)
    
    # Separate by gender
    males = [p for p in participants if p['gender'] == 'M']
    females = [p for p in participants if p['gender'] == 'F']
    
    # Shuffle both lists
    random.shuffle(males)
    random.shuffle(females)
    
    # Calculate number of complete teams possible
    num_possible_teams = min(len(females), len(males) // 2)
    
    primary_teams = []
    # Create complete teams (2 males + 1 female)
    for i in range(num_possible_teams):
        team = [
            females[i],  # One female
            males.pop(0),  # First male
            males.pop(0)   # Second male
        ]
        primary_teams.append(team)
    
    # Handle remaining players
    remaining_males = males
    remaining_females = females[num_possible_teams:]
    remaining_teams = group_remaining_players(remaining_males, remaining_females)
    
    return primary_teams, remaining_teams

# Create participant list with new names
participants = [
    {"name": "Ben Poswell", "gender": "M"},
    {"name": "Charlie Sengupta", "gender": "M"},
    {"name": "Charlie O'Neill", "gender": "M"},
    {"name": "Loki Bromilow", "gender": "M"},
    {"name": "Will Dixson", "gender": "M"},
    {"name": "Timothy Jackman", "gender": "M"},
    {"name": "Jack Wu", "gender": "M"},
    {"name": "Lucas Sieb", "gender": "M"},
    {"name": "Georgie Forrest", "gender": "F"},
    {"name": "Jacqui Farrell", "gender": "F"},
    {"name": "Henry Palmerlee", "gender": "M"},
    {"name": "Indigo Casablanca", "gender": "F"},
    {"name": "Audrey Elwin", "gender": "F"}
]

# Generate teams
primary_teams, remaining_teams = create_teams(participants)

# Print all teams
print("\nPrimary Teams (2M+1F):")
print("-" * 50)
for i, team in enumerate(primary_teams, 1):
    print(f"\nTeam {i}:")
    for player in team:
        print(f"- {player['name']}")

if remaining_teams:
    print("\nAdditional Teams:")
    print("-" * 50)
    for i, team in enumerate(remaining_teams, len(primary_teams) + 1):
        print(f"\nTeam {i}:")
        for player in team:
            print(f"- {player['name']}")


Primary Teams (2M+1F):
--------------------------------------------------

Team 1:
- Jacqui Farrell
- Ben Poswell
- Lucas Sieb

Team 2:
- Georgie Forrest
- Will Dixson
- Charlie O'Neill

Team 3:
- Indigo Casablanca
- Loki Bromilow
- Henry Palmerlee

Team 4:
- Audrey Elwin
- Charlie Sengupta
- Timothy Jackman

Additional Teams:
--------------------------------------------------

Team 5:
- Jack Wu
