# Visualize Data

Load and explore a small subset of the CATH protein chain dataset (`chain_set.jsonl`).

In [1]:
import json
from pathlib import Path

import numpy as np
import torch

In [2]:
# Location of the JSONL chain dataset, relative to this notebook's directory.
DATA_PATH = Path("..") / "data" / "chain_set.jsonl"

In [3]:
def load_jsonl(path: Path, limit: int = 10) -> list[dict]:
    """Load first `limit` records from a JSONL file."""
    records = []
    with open(path) as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            records.append(json.loads(line))
    return records

In [4]:
# Read only the first few records so exploration stays quick
data = load_jsonl(DATA_PATH, limit=5)
print("Loaded", len(data), "records")

Loaded 5 records


In [5]:
def print_structure(record, num_coords_to_show=4, max_seq_len=35):
    """Pretty-print a compact JSON preview of one dataset record.

    The preview contains ``name``, ``CATH`` and ``num_chains`` as-is, the
    sequence cut to ``max_seq_len`` characters (with "..." appended when it
    was longer), and the first ``num_coords_to_show`` entries of every list
    under ``coords``. The input record is not modified.

    Args:
        record: One parsed dataset record (a dict).
        num_coords_to_show: Number of coordinate entries to keep per atom key.
        max_seq_len: Longest sequence shown before truncating with "...".
    """
    # Build a fresh dict so the original record (and its lists) stay untouched.
    preview = {
        'name': record.get('name'),
        'CATH': record.get('CATH'),
        'num_chains': record.get('num_chains'),
    }

    # `or ''` also covers a record whose 'seq' key exists but is None.
    seq = record.get('seq') or ''
    preview['seq'] = (seq[:max_seq_len] + "...") if len(seq) > max_seq_len else seq

    preview['coords'] = {}
    for atom_key, coords_list in (record.get('coords') or {}).items():
        subset = coords_list[:num_coords_to_show]
        # NumPy slices are not JSON-serialisable; convert them to nested lists.
        if isinstance(subset, np.ndarray):
            subset = subset.tolist()
        preview['coords'][atom_key] = subset

    # json.dumps writes NaN coordinates as the bare token NaN (not strict JSON),
    # which is acceptable for a human-readable preview.
    print(json.dumps(preview, indent=2))

# Usage: preview the first record loaded above
sample = data[0]
print_structure(record=sample, num_coords_to_show=4)

{
  "name": "12as.A",
  "CATH": [
    "3.30.930"
  ],
  "num_chains": 8,
  "seq": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAP...",
  "coords": {
    "N": [
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
      [
        11.751,
        37.846,
        29.016
      ]
    ],
    "CA": [
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
      [
        12.501,
        39.048,
        28.539
      ]
    ],
    "C": [
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
      [
        13.74,
        38.628,
        27.754
      ]
    ],
    "O": [
      [
        NaN,
        NaN,
        NaN
      ],
      [
        NaN,
        NaN,
        NaN
      ],
   

In [6]:
# Inspect how coordinates are organised: one list of positions per atom key
coords = sample["coords"]
print(f"Coordinate keys: {list(coords)}")
for atom in coords:
    positions = coords[atom]
    print(f"  {atom}: {len(positions)} positions, first 3: {positions[:3]}")

Coordinate keys: ['N', 'CA', 'C', 'O']
  N: 330 positions, first 3: [[nan, nan, nan], [nan, nan, nan], [nan, nan, nan]]
  CA: 330 positions, first 3: [[nan, nan, nan], [nan, nan, nan], [nan, nan, nan]]
  C: 330 positions, first 3: [[nan, nan, nan], [nan, nan, nan], [nan, nan, nan]]
  O: 330 positions, first 3: [[nan, nan, nan], [nan, nan, nan], [nan, nan, nan]]


## Handling Missing Coordinates

Some residues have NaN coordinates because those regions were not resolved in the experimental structure (in the sample above, the first three residues). We build a per-residue mask so these positions can be excluded during training.

In [7]:
def process_coords(sample: dict) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract backbone coordinates and a per-residue validity mask.

    Args:
        sample: Record with a ``"seq"`` string and a ``"coords"`` dict mapping
            atom names ("N", "CA", "C", "O") to per-residue ``[x, y, z]`` lists.
            Atom names absent from ``coords`` are left as NaN.

    Returns:
        xyz: (L, 4, 3) float32 tensor of backbone atom coordinates (N, CA, C, O),
            with L = len(sample["seq"]).
        mask: (L,) boolean tensor, True where all three CA coordinates are finite.

    Raises:
        ValueError: If an atom's coordinates cannot be read as shape (L, 3).
    """
    coords_dict = sample["coords"]
    L = len(sample["seq"])

    # Stack coordinates: (L, 4, 3), NaN wherever nothing is provided.
    xyz = np.full((L, 4, 3), np.nan, dtype=np.float32)
    for i, atom_name in enumerate(("N", "CA", "C", "O")):
        if atom_name not in coords_dict:
            continue  # absent atom type stays NaN
        atom_xyz = np.asarray(coords_dict[atom_name], dtype=np.float32)
        if atom_xyz.size == 0:
            # np.asarray([]) has shape (0,); give it the (0, 3) shape it implies
            # so empty chains don't fail to broadcast.
            atom_xyz = atom_xyz.reshape(0, 3)
        # Check explicitly: numpy would otherwise silently broadcast e.g. a single
        # [x, y, z] row across every residue.
        if atom_xyz.shape != (L, 3):
            raise ValueError(
                f"coords[{atom_name!r}] has shape {atom_xyz.shape}, "
                f"expected ({L}, 3) to match len(seq)"
            )
        xyz[:, i, :] = atom_xyz

    # Mask based on CA atom (alpha carbon)
    ca_coords = xyz[:, 1, :]  # (L, 3)
    mask = np.isfinite(ca_coords).all(axis=-1)  # (L,)

    return torch.from_numpy(xyz), torch.from_numpy(mask)

In [8]:
# Turn the first sample into a coordinate tensor plus residue mask
xyz, mask = process_coords(sample)

print("Coordinates shape:", xyz.shape)
print("Mask shape:", mask.shape)
print(f"Valid residues: {mask.sum().item()} / {len(mask)} ({100 * mask.float().mean():.1f}%)")
print("First 10 mask values:", mask[:10].tolist())

Coordinates shape: torch.Size([330, 4, 3])
Mask shape: torch.Size([330])
Valid residues: 327 / 330 (99.1%)
First 10 mask values: [False, False, False, True, True, True, True, True, True, True]


In [9]:
DATA_PATH = Path("../data/chain_set.jsonl")

def filter_perfect_samples(path: Path, *, require_full_backbone: bool = False) -> list[dict]:
    """Scan a JSONL dataset and keep records with no unresolved residues.

    By default a record is kept when every entry of the mask returned by
    ``process_coords`` is True, i.e. every residue has a finite CA position.
    A finite CA does not imply finite N, C and O; pass
    ``require_full_backbone=True`` to also require every backbone coordinate
    to be finite.

    Args:
        path: Path to a JSONL file with one record per line.
        require_full_backbone: Also require finite N, C and O coordinates.

    Returns:
        The kept records, in file order.
    """
    perfect_samples = []
    total_count = 0

    print(f"Scanning {path}...")

    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines instead of failing in json.loads
            record = json.loads(line)

            xyz, mask = process_coords(record)

            # Keep the record only if ALL residues pass the selected check
            keep = bool(mask.all())
            if keep and require_full_backbone:
                keep = bool(torch.isfinite(xyz).all())
            if keep:
                perfect_samples.append(record)

            total_count += 1

            if total_count % 1000 == 0:
                print(f"Processed {total_count} records... (Found {len(perfect_samples)} perfect)", end="\r")

    # An empty file would otherwise divide by zero when computing the yield.
    yield_pct = 100 * len(perfect_samples) / total_count if total_count else 0.0

    # Final Stats
    print(f"\n{'='*30}")
    print(f"Total Scanned:   {total_count}")
    print(f"Perfect Kept:    {len(perfect_samples)}")
    print(f"Yield:           {yield_pct:.2f}%")
    print(f"{'='*30}")

    return perfect_samples

# Run the analysis and capture the variable
clean_data = filter_perfect_samples(DATA_PATH)

# Verify the result (guarded: indexing [0] would raise IndexError if nothing passed the filter)
print(f"Variable 'clean_data' now contains {len(clean_data)} records.")
if clean_data:
    print(f"First record name: {clean_data[0]['name']}")

Scanning ../data/chain_set.jsonl...
Processed 21000 records... (Found 4895 perfect)
Total Scanned:   21668
Perfect Kept:    4976
Yield:           22.96%
Variable 'clean_data' now contains 4976 records.
First record name: 132l.A


In [10]:
# Display the raw dict of the first record kept by filter_perfect_samples.
clean_data[0]

{'seq': 'KVFGRCELAAAMKRHGLDNYRGYSLGNWVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCAKKIVSDGNGMNAWVAWRNRCKGTDVQAWIRGCRL',
 'coords': {'N': [[-9.649, 18.097, 49.778],
   [-8.277, 15.485, 47.81],
   [-5.26, 14.268, 47.104],
   [-3.409, 12.981, 44.317],
   [-0.965, 10.363, 44.128],
   [1.596, 10.914, 42.651],
   [1.024, 13.537, 42.118],
   [0.983, 14.691, 44.672],
   [3.629, 14.09, 45.377],
   [5.039, 15.706, 43.563],
   [4.27, 18.324, 44.372],
   [5.255, 18.295, 46.973],
   [7.904, 18.075, 46.471],
   [8.381, 20.545, 45.208],
   [8.139, 22.172, 47.404],
   [10.38, 21.795, 48.747],
   [10.893, 19.332, 50.268],
   [12.283, 17.053, 49.557],
   [14.792, 15.068, 50.662],
   [15.033, 16.719, 52.868],
   [16.523, 17.432, 55.987],
   [17.676, 14.965, 55.659],
   [15.757, 13.114, 54.842],
   [14.555, 11.329, 52.229],
   [12.26, 12.103, 49.52],
   [10.191, 10.018, 49.122],
   [9.704, 9.139, 51.683],
   [8.045, 11.132, 52.874],
   [5.779, 11.048, 51.201],
   [4.847, 8.469, 51.614

In [11]:
from pathlib import Path
from src.data_cath import get_one_chain
from src.viz import show_ca_trace

# Load a protein via get_one_chain, reusing the dataset path defined above
# so the location is configured in one place.
name, seq, ca_coords = get_one_chain(DATA_PATH)
print(f"Loaded {name}, {len(seq)} residues")

Loaded 132l.A, 129 residues


In [14]:
import py3Dmol
from src.pdb_io import ca_to_pdb_str

# Convert the CA coordinates of the chain loaded above into PDB text
pdb_str = ca_to_pdb_str(ca_coords, name=name, seq=seq)

# Interactive 600x400 viewer drawing each atom as a half-scale sphere
view = py3Dmol.view(height=400, width=600)
view.addModel(pdb_str, "pdb")
view.setStyle({"sphere": {"scale": 0.5}})
view.zoomTo()
view.show()