In [31]:
%matplotlib inline

In [2]:
import pandas as pd
from pandas import option_context

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from scipy import stats

In [5]:
import os
import utils
import torch

import sklearn
from sklearn import metrics

import seaborn as sns
import geopandas as gpd

In [1]:
import probe

# `utils` and `probe` are folders from the paper's GitHub repository, imported here as local modules

<p style="font-size: 20.5px; text-align: center"> Language Models Represent Space and Time

<p style="font-size: 15px; text-align: center"> Wes Gurnee and Max Tegmark, Massachusetts Institute of Technology

<p style="font-size: 20; text-align: center; color: gray"> Summary by Irina Nedyalkova, Deep Learning Student

Max Tegmark and Wes Gurnee find that Large Language Models (LLMs) learn linear representations of space and time across multiple scales. They analyze the learned representations of three spatial datasets (world, US, and NYC places) and three temporal datasets (historical figures, artworks, and news headlines) in the Llama-2 family of models. They also identify individual "space neurons" and "time neurons" that reliably encode spatial and temporal coordinates. Their results suggest that modern LLMs learn rich spatio-temporal representations of the real world and possess basic ingredients of a world model.

Despite being trained only to predict the next token, Large Language Models (LLMs) have demonstrated an impressive set of capabilities, raising questions about what such models have actually learned. One hypothesis is that LLMs learn a massive collection of correlations but lack any understanding of the data. An alternative hypothesis is that LLMs, in the course of compressing the data, learn more compact, coherent, and interpretable models of the generative process underlying the training data.

In this work, Gurnee and Tegmark ask whether LLMs form world (and temporal) models in the most literal sense possible: they attempt to extract an actual map of the world from the model. While such spatio-temporal representations do not by themselves constitute a dynamic causal world model, coherent multi-scale representations of space and time are basic ingredients of a more comprehensive one.

The authors build six datasets of place or event names with corresponding space or time coordinates, spanning multiple spatio-temporal scales:
- locations within the whole world, the United States, and New York City;
- the death year of historical figures from the past 3000 years;
- the release date of art and entertainment from the 1950s onward;
- the publication date of news headlines from 2010 to 2020.

Using the Llama-2 family of models, they train linear regression probes on the internal activations of these place and event names at each layer to predict the real-world location (latitude/longitude) or time (numeric timestamp). These probing experiments show that the models build spatial and temporal representations throughout the early layers, with quality plateauing around the halfway point of the model, and that larger models consistently outperform smaller ones. The representations are linear (nonlinear probes do not perform better), fairly robust to changes in prompting, and unified across different kinds of entities (e.g. cities and landmarks). Finally, the probes are used to find individual neurons that activate as a function of space or time, providing strong evidence that the model actually uses these features.

Here are entity counts and representative examples for each dataset:

In [3]:
pd.set_option("display.max_colwidth", None)      # show full cell contents instead of truncating long strings

pd.read_excel("a1 datasets example.xlsx")

| Dataset | Count | Examples |
|---|---|---|
| World | 39585 | "Los Angeles", "St. Peter's Basilica", "Canary Islands" |
| USA | 29997 | "Fenway Park", "Columbia University", "Riverside County" |
| NYC | 19838 | "Borden Avenue Bridge", "Trump International Hotel" |
| Figures | 37539 | "Cleopatra", "Dante Alighieri", "Carl Sagan" |
| Artworks | 31321 | "Stephen King's It", "Queen's Bohemian Rhapsody" |
| Headlines | 28389 | "Pilgrims, Fewer and Socially Distanced, Arrived in Mecca for Annual Hajj" |


All experiments use the base Llama-2 series of auto-regressive transformer language models, ranging from 7 billion to 70 billion parameters. For each dataset, the authors run every entity name through the model, optionally prepended with a short prompt, and save the activations of the hidden state (the residual stream, i.e. the sum of the token embedding and the outputs of all prior layers) on the last entity token at every layer.
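A minimal sketch of the "last entity token" step, assuming the hidden states for a batch have already been stacked into one array of shape (layers, batch, sequence, d_model) (for example from a Hugging Face model called with `output_hidden_states=True`), that sequences are right-padded, and that the entity name sits at the end of the input because any prompt is prepended. The helper name `last_token_activations` is hypothetical and not taken from the authors' repository.

```python
import numpy as np


def last_token_activations(hidden_states, attention_mask):
    """Select each sequence's final non-padding token at every layer (hypothetical helper).

    hidden_states: array of shape (n_layers, batch, seq_len, d_model)
    attention_mask: array of shape (batch, seq_len); 1 = real token, 0 = right padding
    Returns an array of shape (n_layers, batch, d_model).
    """
    hidden_states = np.asarray(hidden_states)
    attention_mask = np.asarray(attention_mask)
    # Position of the last real token in each right-padded sequence
    last_idx = attention_mask.sum(axis=1) - 1
    batch_idx = np.arange(hidden_states.shape[1])
    # Pair each sequence with its own last position, keeping the layer and feature axes
    return hidden_states[:, batch_idx, last_idx, :]


# Toy example: 2 layers, 2 sequences of length 3, 1 feature; the second sequence has one pad token
hs = np.arange(2 * 2 * 3 * 1).reshape(2, 2, 3, 1)
mask = np.array([[1, 1, 1], [1, 1, 0]])
acts = last_token_activations(hs, mask)
print(acts[:, :, 0].tolist())  # [[2, 4], [8, 10]]
```

With left padding (common for batched generation) the last real token is always at position `-1`, so the index computation would change accordingly.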

To find evidence of spatial and temporal representations in LLMs, Gurnee and Tegmark use the standard technique of probing: fitting a simple model on the network's activations to predict a target label associated with the labeled input data. Concretely, given an activation dataset and a target (either a time or a two-dimensional latitude/longitude coordinate), they fit linear ridge regression probes, yielding a linear predictor. High predictive performance on out-of-sample data indicates that temporal and spatial information is linearly decodable from the base model's representations, although this does not by itself imply that the model actually uses these representations.
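The probing step can be sketched with scikit-learn's `Ridge`. This is a minimal sketch on synthetic data: the random activations, the hidden linear map that produces two "latitude/longitude" targets, the 80/20 split, and `alpha=1.0` are all assumptions for illustration, not the authors' data or settings.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n_entities, d_model = 2000, 64

# Synthetic stand-in for one layer's activations (not real Llama-2 activations)
X = rng.normal(size=(n_entities, d_model))
true_weights = rng.normal(size=(d_model, 2))
# Two targets (think latitude and longitude) that are a noisy linear function of X
y = X @ true_weights + 0.1 * rng.normal(size=(n_entities, 2))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

probe = Ridge(alpha=1.0)  # regularization strength chosen arbitrarily for this sketch
probe.fit(X_train, y_train)  # Ridge handles both target columns at once

# Out-of-sample R^2, averaged uniformly over the two target columns
test_r2 = r2_score(y_test, probe.predict(X_test))
print(f"out-of-sample R^2: {test_r2:.3f}")
```

The probe's weight matrix (`probe.coef_`, one row per target) is the learned linear direction in activation space; evaluating only on held-out entities is what supports the "linearly decodable" reading.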

To evaluate the probes, they report standard regression metrics such as $R^2$ and Spearman rank correlation on the test data (for spatial features, correlations are averaged over latitude and longitude). They also compute a proximity error for each prediction, defined as the fraction of entities predicted to be closer to the target point than the prediction for the target entity itself. The intuition is that for spatial data, absolute error metrics can be misleading: a 500 km error for a city on the East Coast of the United States is far more significant than a 500 km error in Siberia.
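The three metrics can be sketched as follows. This follows a literal reading of the proximity error definition (comparing other entities' *predictions* against the target entity's true location); the authors' repository may implement it differently. Plain Euclidean distance on raw coordinates is a simplification assumed here, where a great-circle distance would be more faithful for latitude/longitude. Function names are hypothetical.

```python
import numpy as np
from scipy.stats import spearmanr
from sklearn.metrics import r2_score


def spatial_metrics(y_true, y_pred):
    """Return (R^2, Spearman rho), each averaged over the coordinate columns."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    r2 = r2_score(y_true, y_pred)  # default: uniform average over the output columns
    rhos = []
    for k in range(y_true.shape[1]):
        rho, _ = spearmanr(y_true[:, k], y_pred[:, k])
        rhos.append(rho)
    return r2, float(np.mean(rhos))


def proximity_error(y_true, y_pred):
    """For each entity i: the fraction of all entities whose prediction lies strictly closer
    to i's true location than i's own prediction does (Euclidean distance, a simplification)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = np.empty(len(y_true))
    for i, target in enumerate(y_true):
        own_dist = np.linalg.norm(y_pred[i] - target)
        all_dists = np.linalg.norm(y_pred - target, axis=1)
        errors[i] = np.mean(all_dists < own_dist)
    return errors


# Entity 0 is predicted 2 units away, but entity 1's prediction is only 1 unit from it
y_true = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
y_pred = np.array([[2.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
print(proximity_error(y_true, y_pred))  # entity 0 gets 1/3, the others 0
```

Because it counts neighbours rather than kilometres, the same absolute error is penalized more in dense regions (many nearby entities) than in sparse ones, which is the East Coast versus Siberia intuition.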

In [29]:
world_places = pd.read_csv("world_places.csv")

In [7]:
us_places = pd.read_csv("us_places.csv")

In [8]:
nyc_places = pd.read_csv("nyc_places.csv")

In [9]:
figures = pd.read_csv("historical_figures.csv")

In [10]:
artworks = pd.read_csv("artworks.csv")

In [11]:
headlines = pd.read_csv("headlines.csv")

In [12]:
#some source code here

Do models represent space and time at all? If so, where internally in the model? Does the representation quality change substantially with model scale?

Both spatial and temporal features can be recovered with a linear probe. The quality of these representations increases smoothly over the first half of the model's layers before reaching a plateau, and the representations become more accurate with increasing model scale. The New York City dataset yields the worst performance, but it is also the dataset where the largest model has the best relative performance, suggesting that sufficiently large LLMs could eventually form detailed spatial models of individual cities.
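The layer-by-layer picture comes from fitting one probe per layer on a shared train/test split and tracking out-of-sample $R^2$. The sketch below uses toy activations whose signal strength is set by hand (none at layer 0, growing until layer 3, constant afterwards), so the resulting curve only illustrates the procedure and says nothing about Llama-2; `r2_by_layer` is a hypothetical helper name and `alpha=1.0` an assumed setting.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split


def r2_by_layer(layer_activations, y, alpha=1.0, seed=0):
    """Fit one ridge probe per layer and return each layer's out-of-sample R^2.

    layer_activations: array of shape (n_layers, n_entities, d_model)
    y: targets of shape (n_entities,) or (n_entities, n_targets)
    """
    # One split shared by all layers, so scores are comparable across depth
    idx_train, idx_test = train_test_split(np.arange(len(y)), test_size=0.2, random_state=seed)
    scores = []
    for X in layer_activations:
        probe = Ridge(alpha=alpha).fit(X[idx_train], y[idx_train])
        scores.append(r2_score(y[idx_test], probe.predict(X[idx_test])))
    return np.array(scores)


# Toy activations: the target is written into the activations with a hand-set strength per layer
rng = np.random.default_rng(0)
n, d, n_layers = 1000, 32, 6
y = rng.normal(size=n)
signal = np.outer(y, rng.normal(size=d))
acts = np.stack([min(layer, 3) / 3 * signal + rng.normal(size=(n, d)) for layer in range(n_layers)])

scores = r2_by_layer(acts, y)
print(np.round(scores, 3))
```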

Here is the out-of-sample $R^2$ of linear and nonlinear (one-hidden-layer MLP) probes for all models and features at 60% layer depth:

In [30]:
pd.read_excel("a1 table2.xlsx")

| Model | Probe | World | USA | NYC | Figures | Artworks | Headlines |
|---|---|---|---|---|---|---|---|
| Llama-2-7b | Linear | 0.881 | 0.799 | 0.219 | 0.785 | 0.788 | 0.564 |
| Llama-2-7b | MLP | 0.897 | 0.819 | 0.204 | 0.775 | 0.746 | 0.467 |
| Llama-2-13b | Linear | 0.896 | 0.825 | 0.237 | 0.804 | 0.806 | 0.645 |
| Llama-2-13b | MLP | 0.916 | 0.824 | 0.230 | 0.818 | 0.808 | 0.656 |
| Llama-2-70b | Linear | 0.911 | 0.864 | 0.359 | 0.835 | 0.885 | 0.746 |
| Llama-2-70b | MLP | 0.926 | 0.869 | 0.312 | 0.839 | 0.884 | 0.739 |
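Subtracting the linear-probe row from the MLP-probe row for each model makes the size of the nonlinear gain explicit. The values below are copied from the table above; only the pandas layout is new.

```python
import pandas as pd

datasets = ["World", "USA", "NYC", "Figures", "Artworks", "Headlines"]
models = ["Llama-2-7b", "Llama-2-13b", "Llama-2-70b"]

# Out-of-sample R^2 at 60% layer depth, copied from the table above
linear = pd.DataFrame(
    [[0.881, 0.799, 0.219, 0.785, 0.788, 0.564],
     [0.896, 0.825, 0.237, 0.804, 0.806, 0.645],
     [0.911, 0.864, 0.359, 0.835, 0.885, 0.746]],
    index=models, columns=datasets)
mlp = pd.DataFrame(
    [[0.897, 0.819, 0.204, 0.775, 0.746, 0.467],
     [0.916, 0.824, 0.230, 0.818, 0.808, 0.656],
     [0.926, 0.869, 0.312, 0.839, 0.884, 0.739]],
    index=models, columns=datasets)

gain = (mlp - linear).round(3)  # positive = MLP probe better than linear probe
print(gain)
print("largest MLP gain:", gain.values.max())
print("cells where MLP is worse:", int((gain < 0).values.sum()), "of", gain.size)
```

The largest improvement from the MLP probe is 0.02 $R^2$, and in 9 of the 18 model/dataset cells the MLP probe actually scores lower than the linear one.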


To test whether spatial and temporal features are represented linearly, the authors compare the linear ridge regression probes with substantially more expressive nonlinear MLP probes of the form $W_2\,\mathrm{ReLU}(W_1 x + b_1) + b_2$ with 256 hidden neurons. The table above shows that nonlinear probes yield little or no improvement in $R^2$ for any dataset or model. The authors take this as strong evidence that space and time are represented linearly (or at the very least are linearly decodable), despite being continuous.
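The linear-versus-MLP comparison can be set up with scikit-learn: an `MLPRegressor` with one hidden layer of 256 ReLU units and its identity output layer computes exactly $W_2\,\mathrm{ReLU}(W_1 x + b_1) + b_2$. This is a minimal sketch on synthetic data that is linear by construction; the training hyperparameters (`max_iter`, `random_state`, the Ridge `alpha`) are assumptions, so it shows only how such a comparison is wired, not the paper's numbers.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(1500, 32))                            # hypothetical activations
y = X @ rng.normal(size=32) + 0.1 * rng.normal(size=1500)  # target encoded linearly
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)

linear_probe = Ridge(alpha=1.0).fit(X_tr, y_tr)
# One hidden layer of 256 ReLU units followed by a linear output: W2 ReLU(W1 x + b1) + b2
mlp_probe = MLPRegressor(hidden_layer_sizes=(256,), activation="relu",
                         max_iter=500, random_state=0).fit(X_tr, y_tr)

r2_linear = r2_score(y_te, linear_probe.predict(X_te))
r2_mlp = r2_score(y_te, mlp_probe.predict(X_te))
print(f"linear R^2 = {r2_linear:.3f}, MLP R^2 = {r2_mlp:.3f}")
```

The logic of the test is that if a much more expressive probe cannot beat the linear one on held-out data, the extra capacity is not uncovering information that a linear readout misses.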

to be continued...

________________________________________________________________________________________________________________________________

Original paper: https://arxiv.org/pdf/2310.02207v2.pdf
Source code: https://github.com/wesg52/world-models/tree/main