<a href="https://colab.research.google.com/github/hululuzhu/gpt-j/blob/main/GPT_J_6B_Iference_demo_and_chinese_coding_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is a fork of [kingoflolz's GPT-J colab](https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb)

- Modified a few places for 

# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!
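
One way to confirm the runtime before starting the long download is to look for the `COLAB_TPU_ADDR` environment variable, which the setup cell in this notebook reads. This is a minimal sketch; the helper name `has_tpu_runtime` is hypothetical and not part of the original notebook.

```python
import os


def has_tpu_runtime(environ=None):
    """Return True if the Colab TPU address variable is set and non-empty."""
    env = os.environ if environ is None else environ
    return bool(env.get("COLAB_TPU_ADDR"))


if not has_tpu_runtime():
    # Assumption: the check runs inside Colab, where the runtime type is chosen from the Runtime menu.
    print("No TPU detected: choose Runtime > Change runtime type > TPU.")
```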

In [None]:
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd -q


real	1m46.693s
user	0m5.344s
sys	0m23.934s


In [None]:
!apt install zstd -q

# The "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory use
# !time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
# -r must be followed directly by the requirements file, so -q goes first
!pip install -q -r mesh-transformer-jax/requirements.txt
 
# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install -q mesh-transformer-jax/ jax==0.2.12

# Appears necessary to avoid the tokenizer error below; there seems to be an incompatibility between the transformers library and TF Keras
!pip uninstall -y tensorflow
!pip install tensorflow==2.3.0

Reading package lists...
Building dependency tree...
Reading state information...
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 40 not upgraded.
Need to get 278 kB of archives.
After this operation, 1,141 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]
Fetched 278 kB in 1s (375 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 160837 files and directories currently installed.)
Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...
Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...
Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...

real	1m29.566s
user	0m29.950s
sys	0m26.215s
Cloning into 'mesh-transformer-jax'...
remote: Enumerating objects: 668, done.
remote: Counting objects: 100% (276/276), done.
remote: Compressing objects: 100% (86/86), done.

## Setup Model


In [None]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

In [None]:
!pip install -q optax
!pip install -q transformers


In [None]:
# Added as of 08/16
!pip install -q dm-haiku 
!pip install -q einops
!pip install -q ray


Sometimes the next step fails for no obvious reason; if it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')


Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344
read from disk/gcs in 16.6356s


## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute due to compilation, but after that it should only take about 10 seconds per sample.

Feel free to mess with the different sampling parameters (top_p and temp), as well as the length of the generations (gen_len, causes a recompile when changed).
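
To build intuition for what `top_p` and `temp` control, here is a generic NumPy sketch of temperature scaling followed by top-p (nucleus) filtering. It is an illustration under stated assumptions, not the `nucleaus_sample` implementation from `mesh_transformer`; the function names `top_p_filter` and `sample` are hypothetical, and the order of scaling and filtering may differ from the library's.

```python
import numpy as np


def top_p_filter(logits, top_p):
    """Keep the most likely tokens until their cumulative probability passes top_p; mask the rest."""
    logits = np.asarray(logits, dtype=float)
    order = np.argsort(logits)[::-1]              # token ids, most likely first
    sorted_logits = logits[order]
    probs = np.exp(sorted_logits - sorted_logits.max())
    probs /= probs.sum()
    cumulative = np.cumsum(probs)
    # Drop a token if the mass before it already exceeds top_p;
    # the most likely token is always kept.
    remove = np.concatenate(([False], cumulative[:-1] > top_p))
    filtered = logits.copy()
    filtered[order[remove]] = -np.inf
    return filtered


def sample(logits, top_p=0.9, temp=1.0, rng=None):
    """Draw one token id after temperature scaling and top-p filtering."""
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=float) / temp   # this sketch requires temp > 0
    filtered = top_p_filter(scaled, top_p)
    probs = np.exp(filtered - filtered.max())
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


logits = [2.0, 1.0, 0.0]
print(sample(logits, top_p=0.0))  # always 0 in this sketch: only the top token survives
```

In this sketch a lower `temp` sharpens the distribution and a lower `top_p` shrinks the candidate set, so `top_p=0` reduces to greedy decoding.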

You can also change other things like per_replica_batch in the previous cells to change how many generations are done in parallel. A larger batch has higher latency but higher throughput when measured in tokens generated/s. This is useful for doing things like best-of-n cherry picking.
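
The number of parallel generations follows from the `total_batch` formula used when the network is created. A small worked example, assuming the 8-core TPU reported in the load log (`dp 1`, `mp 8`); the function name `total_batch_size` is hypothetical:

```python
def total_batch_size(per_replica_batch, device_count, cores_per_replica):
    """Mirror the notebook's formula: each replica spans cores_per_replica devices."""
    return per_replica_batch * device_count // cores_per_replica


# Assumption: 8 devices with the model sharded across all 8 (a single replica).
print(total_batch_size(1, 8, 8))  # 1 generation per call
print(total_batch_size(4, 8, 8))  # 4 generations per call
```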

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
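
A small helper can apply this tip before each call to `infer`. This is a sketch; `clean_prompt` is a hypothetical name, and it deliberately leaves a final newline alone because some prompts in this notebook end with one on purpose.

```python
def clean_prompt(prompt):
    """Remove spaces and tabs at the very end of the prompt; newlines are left untouched."""
    return prompt.rstrip(" \t")


print(repr(clean_prompt("EleutherAI is ")))  # 'EleutherAI is'
print(repr(clean_prompt("def f(x):\n")))     # 'def f(x):\n'
```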

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    # print(f"completion done in {time.time() - start:06}s")
    return samples

# print(infer("EleutherAI is")[0])

In [None]:
#@title  { form-width: "300px" }
top_p = 1 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])

completion done in 13.511082410812378s
[1mIn a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.[0m Scientists, studying the conversation, suspected that there was no rational explanation for the situation. As a precaution, it was decided that another team must be sent to go and check up on the herd.

Castlevania: Lords Of Shadow - Dawn Of Destiny

Homicide : The Corpse Wasn't There

An unspeakable pair of evil fields proceeded to face off against each other: the Crepidius. The goal? To find out (or keep them from figuring it out in the first place). Since a fellow scientist had recently gone missing, two scientists on vacation had been sent to research. As each whole was set loose, a terrible pandemonium ensued. And why not - the Crepidius were completely aware that in a place where people could shoot fireball

In [None]:
# context = """public String get_first_char(String input_str) {"""


context = """def get_first_char(String input_str):
"""
print(infer(top_p=top_p, temp=temp, gen_len=64, context=context)[0])

completion done in 41.18787169456482s
[1mdef get_first_char(String input_str):
[0m    return safe_get_first_char(input_str, character_set="utf-8")

stable_input_str = your_new_string

encoded_now = str(encode(stable_input_str, "ascii", errors="strict


In [None]:
context="""# function to check string is
# palindrome or not
def isPalindrome(s):
"""
print(infer(top_p=top_p, temp=temp, gen_len=64, context=context)[0])

completion done in 1.9251036643981934s
[1m# function to check string is
# palindrome or not
def isPalindrome(s):
[0m    s="ABCDEFGHIJKLMNOPQRSTUVWXYZ" # length of given string.
    if len(s)%2==1:
        k=len(s)
    else:
  


In [None]:
context=""""We make a living by what we get, but we make a life by what we give," is one of my favorite quotes by the late Winston Churchill. I gave six hours to a non-profit charity this past week by selling t-shirts to promote breast cancer awareness and it was most fulfilling. Seeing my hard work make a difference inspired me to provide services for the local Minot Human Society as well. In the next passages I will explain my tasks, personal thoughts on community service conducted and the correlation between the chapter lesson and my community efforts. ."""
print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])

completion done in 13.540693759918213s
[1m"We make a living by what we get, but we make a life by what we give," is one of my favorite quotes by the late Winston Churchill. I gave six hours to a non-profit charity this past week by selling t-shirts to promote breast cancer awareness and it was most fulfilling. Seeing my hard work make a difference inspired me to provide services for the local Minot Human Society as well. In the next passages I will explain my tasks, personal thoughts on community service conducted and the correlation between the chapter lesson and my community efforts. .[0m

Lesson #36: Promoting Breast Cancer Awareness

I found three locations that would be easiest to get to and provide the best opportunity to sell t-shirts. While waiting for the boys, I studied classes for the Breast Cancer Foundation “You're Never to Young to Get Breast Cancer” course, for Cancer Information Services “ How to Identify Breast Cancer in Your Friends and Family” and for NEThealth “Ho

# Chinese Inference Examples

In [None]:
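# Prompt in English: "Question: How do you compute the largest of three numbers in Python or Java? Answer:"
context = """问题：如何用python或者java计算三个数中最大的那个数？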
答案:"""

In [None]:
%%time
context = """question: how to write a python program to get the largest of 3 integer numbers?
answer:"""
print(infer(gen_len=128, context=context)[0])

[1mquestion: how to write a python program to get the largest of 3 integer numbers?
answer:[0m
def largest(a, b, c):
    if a > b:
        if a > c:
            return a
        else:
            return b
    else:
        if b > c:
            return b
        else:
            return c

CPU times: user 559 ms, sys: 2.92 s, total: 3.48 s
Wall time: 3.59 s


In [None]:
%%time
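# Prompt in English: "Question: How can you compute the largest of three integers in Python? Answer:"
context = """问题：如何能用python计算三个整数中最大的那个数？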
答案:"""
print(infer(gen_len=128, context=context)[0])

[1m问题：如何能用python计算三个整数中最大的那个数？
答案:[0m
import math

def max_num(a, b, c):
    max_num = max(a, b, c)
    return max_num

print(max_num(1, 2, 3))

# 输出:
# 3

# 参考：
# https://www.zhihu.com/question/27058981/answer/27058981

# 参考：
# https://www.zhihu.com/question/27
CPU times: user 711 ms, sys: 2.4 s, total: 3.12 s
Wall time: 3.59 s


In [None]:
%%time
context = """question: which programming language is the best?
answer:"""
print(infer(gen_len=8, context=context)[0])

[1mquestion: which programming language is the best?
answer:[0m it depends

I'm a programmer
CPU times: user 170 ms, sys: 330 ms, total: 500 ms
Wall time: 480 ms


In [None]:
%%time
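# Prompt in English: "Question: Which programming language is the best programming language? Answer:"
context = """问题：哪个编程语言是最好的编程语言？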
答案:"""
print(infer(gen_len=64, context=context)[0])

[1m问题：哪个编程语言是最好的编程语言？
答案:[0m C++

### 关于编程语言的更多细节

#### 关于编程语言的更多细节
CPU times: user 604 ms, sys: 1.28 s, total: 1.88 s
Wall time: 1.97 s


In [None]:
%%time
# Prompt in English: "Question: Is PHP the best programming language in the world? Answer:"
context = """问题：php是世界上最好的编程语言吗？
答案:"""
print(infer(gen_len=64, context=context)[0])

[1m问题：php是世界上最好的编程语言吗？
答案:[0m 嗯，不是。

## 关于PHP的编程语言

PHP是一个编程语言，它的编程语
CPU times: user 399 ms, sys: 1.42 s, total: 1.81 s
Wall time: 1.94 s


In [None]:
%%time
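# Prompt in English: "Question: Write a Java program that gets the second-largest integer from 3 input numbers. Answer:"
context = """问题：写一个java程序从3个输入数字中获取第二大整数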
答案:"""
print(infer(gen_len=128, context=context)[0])

[1m问题：写一个java程序从3个输入数字中获取第二大整数
答案:[0m

```java
public class Solution {
    public int getSecondLargest(int[] nums) {
        int max = nums[0];
        for (int i = 1; i < nums.length; i++) {
            if (nums[i] > max) {
                max = nums[i];
            }
CPU times: user 696 ms, sys: 2.71 s, total: 3.4 s
Wall time: 3.58 s


In [None]:
%%time
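# Prompt in English: "Question: Write a shell program that reads the longest string from an input txt file. Answer:"
context = """问题：写一个shell程序从输入txt文件中读取最长的string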
答案:"""
print(infer(gen_len=128, context=context)[0])

[1m问题：写一个shell程序从输入txt文件中读取最长的string
答案:[0m
#!/bin/bash

while read -r line
do
        if [[ $line =~ ^[a-zA-Z0-9_]+$ ]]
        then
                echo $line
        fi
done < $1

参考：

https://www.cnblogs.com/jiajun-liu/p/8454571.html
https://www.cn
CPU times: user 601 ms, sys: 2.55 s, total: 3.15 s
Wall time: 3.58 s


In [None]:
%%time
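# Prompt in English: "Question: Write CSS that produces a purple-red gradient with rounded corners. Answer:"
context = """问题：写一个css实现渐变紫红色圆角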
答案:"""
print(infer(gen_len=128, context=context)[0])

[1m问题：写一个css实现渐变紫红色圆角
答案:[0m

.container {
  width: 100%;
  height: 100%;
  background: #f00;
  overflow: hidden;
}

.container:before {
  content: "";
  position: absolute;
  top: 0;
  left: 0;
  right: 0;
  bottom: 0;
  background: #f00;
  transform: skew(20deg);
  transform-origin: top left;
}
<div class="container">
CPU times: user 680 ms, sys: 2.57 s, total: 3.25 s
Wall time: 3.58 s


In [None]:
%%time
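# Prompt in English: "Question: Write an HTML program whose page shows the three large characters 欢迎你 (Welcome). Answer:"
context = """问题：写一个html程序，页面上有“欢迎你”三个大字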
答案:"""
print(infer(gen_len=128, context=context)[0])

[1m问题：写一个html程序，页面上有“欢迎你”三个大字
答案:[0m

```html
<!DOCTYPE html>
<html>
<head>
    <title>欢迎你</title>
</head>
<body>
    <h1>欢迎你</h1>
</body>
</html>
```

### 关于编程

- [编程的经验](https://www.zhihu.com/question/20480072)
- [
CPU times: user 623 ms, sys: 2.74 s, total: 3.36 s
Wall time: 3.58 s


In [None]:
%%time
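# Prompt in English: "Question: Write a SQL query to find all August spending records in Shanghai
# for the name Xiaofan (小凡) in the data101 table. Answer:"
context = """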
问题：写一个SQL来查找data101数据表中名字（name）是“小凡”的所有8月份在上海的消费记录
答案:"""
print(infer(gen_len=64, context=context)[0])

[1m
问题：帮我写一个SQL来查找data101数据表中名字（name）是“小凡”的所有8月份在上海的消费记录
答案:[0m
SELECT * FROM data101 WHERE name='小凡' AND month='8' AND city='上海'

问题：帮我写一个SQL来查找data101数据表中�
CPU times: user 405 ms, sys: 1.39 s, total: 1.8 s
Wall time: 1.93 s


In [None]:
%%time
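# Prompt in English: "Question: Write a web crawler to find public-spirited official accounts (公众号)
# online that have reported Xiaofan (小凡). Answer:"
context = """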
问题：写一个网络爬虫去查找网络上举报“小凡”的热心公众号
答案:"""
print(infer(gen_len=256, context=context)[0])

[1m
问题：写一个网络爬虫去查找网络上举报“小凡”的热心公众号
答案:[0m

```python
import requests
import json
import time

def get_hot_posts(url):
    r = requests.get(url)
    posts = []
    for i in r.json():
        posts.append(i)
    return posts

def get_hot_posts_by_user(user_id):
    url = 'https://www.weibo.com/' + user_id + '/hot_posts'
    r = requests.get(url)
    posts = []
    for i in r.json():
        posts.append(i)
    return posts

def get_hot_posts_by_user_and_time(user_id, time):
    url = 'https://www.weibo.com/' + user_id + '/hot_posts'
    r = requests.get(url)
    posts = []
    for i in r.json():
      
CPU times: user 1.22 s, sys: 4.34 s, total: 5.56 s
Wall time: 6.91 s


In [None]:
%%time
context = """Questions: How to find the largest of 5 numbers with python or java?
Answer:"""
# print(infer(top_p=0, temp=0, gen_len=128, context=context)[0])
print(infer(top_p=0, temp=0, gen_len=128, context=context)[0])

[1mQuestions: How to find the largest of 5 numbers with python or java?
Answer:[0m
import math

def largest(numbers):
    max_number = numbers[0]
    for number in numbers:
        if number > max_number:
            max_number = number
    return max_number

print(largest([1, 2, 3, 4, 5]))

A:

You can use the max() function.
>>> max(1, 2, 3, 4, 5)
5

A:

You can use
CPU times: user 745 ms, sys: 2.25 s, total: 2.99 s
Wall time: 3.58 s


In [None]:
# %%time
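# Prompt in English (few-shot):
# Q: Xiaogang is Xiaoming's friend, and Xiaogang is also Xiaohua's friend. Who are Xiaogang's friends?
# A: Xiaoming and Xiaohua.
# Q: Xiaogang is Xiaoming's friend, and Xiaoming does not like Xiaohua. Whom does Xiaohua like?
# A: Maybe Xiaogang, maybe Xiaoming.
# Q: Xiaofan likes fooling around outside. Xiaohua still likes Xiaofan a lot. Who is the bad person?
# A:
context = """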
问题：小刚是小明的朋友，小刚还是小花的朋友。小刚的朋友是谁？
答案：小明和小花。

问题：小刚是小明的朋友，小明不喜欢小花。小花喜欢谁？
答案：可能是小刚，也可能是小明。

问题：小凡喜欢在外面鬼混。小花还很喜欢小凡。谁是坏人？
答案："""
print(infer(top_p=0, temp=0, gen_len=4, context=context)[0])

[1m
问题：小刚是小明的朋友，小刚还是小花的朋友。小刚的朋友是谁？
答案：小明和小花。

问题：小刚是小明的朋友，小明不喜欢小花。小花喜欢谁？
答案：可能是小刚，也可能是小明。

问题：小凡喜欢在外面鬼混。小花还很喜欢小凡。谁是坏人？
答案：[0m小花


In [None]:
%%time
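# Prompt in English: "Question: Xiaogang is Xiaoming's friend, and Xiaoming does not like Xiaohua. Whom does Xiaohua like? Answer:"
context = """问题：小刚是小明的朋友，小明不喜欢小花。小花喜欢谁？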
答案："""
print(infer(top_p=0, temp=0, gen_len=20, context=context)[0])

[1m问题：小刚是小明的朋友，小明不喜欢小花。小花喜欢谁？
答案：[0m小花喜欢小刚。

English: 
CPU times: user 700 ms, sys: 373 ms, total: 1.07 s
Wall time: 914 ms


In [None]:
%%time
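# Prompt in English: "Question: Xiaogang is Xiaoming's friend, and Xiaoming does not like Xiaohua. Whom does Xiaohua like? Answer:"
context = """问题：小刚是小明的朋友，小明不喜欢小花。小花喜欢谁？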
答案："""
print(infer(top_p=0, temp=0, gen_len=20, context=context)[0])

completion done in 0.7992494106292725s
[1m问题：小刚是小明的朋友，小明不喜欢小花。小花喜欢谁？
答案：[0m小花喜欢小刚。

English: 
CPU times: user 246 ms, sys: 582 ms, total: 828 ms
Wall time: 803 ms


In [None]:
%%time
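# Prompt in English: "Question: Xiaoming and Xiaohua are friends, but Xiaohua likes Xiaofan more.
# Xiaoming hates Xiaofan and tells Xiaohua that Xiaofan is actually a toothpick. Who is the toothpick? Answer:"
context = """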
问题：小明和小花是朋友。但是小花更喜欢小凡。小明就很讨厌小凡，告诉小花其实小凡是牙签。请问谁是牙签？
答案："""
print(infer(top_p=0, temp=0, gen_len=10, context=context)[0])

[1m
问题：小明和小花是朋友。但是小花更喜欢小凡。小明就很讨厌小凡，告诉小花其实小凡是牙签。请问谁是牙签？
答案：[0m小明。

问�
CPU times: user 186 ms, sys: 349 ms, total: 535 ms
Wall time: 533 ms


In [None]:
%%time
# Prompt in English: "Question: Write a Java test file to test a palindrome algorithm. Answer:"
context = """问题：写一个java的测试文件来测试一个palindrome的算法
答案："""
print(infer(top_p=0, temp=0, gen_len=256, context=context)[0])

[1m问题：写一个java的测试文件来测试一个palindrome的算法
答案：[0m

1. 写一个测试文件来测试一个palindrome的算法

2. 写一个测试文件来测试一个palindrome的算法

3. 写一个测试文件来测试一个palindrome的算法

4. 写一个测试文件来测试一个palindrome的算法

5. 写一个测试文件来测试一个palindrome的算法

6. 写一个测试文件来测试一个palindrome的算法

7. 写一个测试文件来测试一个palindrome的算法

8.
CPU times: user 1.29 s, sys: 4.76 s, total: 6.05 s
Wall time: 6.9 s


In [None]:
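# Prompt in English: "Question: Write a Python program that returns the middle one of three input numbers. Answer:"
context = """问题：用python写一个程序，能得到输入的三个数中间的那个数字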
答案："""
print(infer(top_p=0, temp=0, gen_len=256, context=context)[0])

In [None]:
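# Prompt in English (one-shot):
# Q: Write a Python program that gets the second-largest integer from 3 input numbers. A: (the code example)
# Q: Write a Python program that gets the 4th-largest float from 6 input numbers. A:
context = """问题：编写一个python程序从3个输入数字中获取第二大整数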
回答：```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

问题：编写一个python程序，从6个输入数字中得到第4大的浮点数
回答："""
print(infer(top_p=0, temp=0, gen_len=64, context=context)[0])

[1m问题：编写一个python程序从3个输入数字中获取第二大整数
回答：```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

问题：编写一个python程序，从6个输入数字中得到第4大的浮点数
回答：[0m```python
def middle_number(a, b, c, d, e, f):
    return float(a+b+c+d+e+f)/6
```

问题：编写一个python�


In [None]:
context = """Question: write a python program to get 2nd largest integer from 3 input numbers
Answer: ```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

Question: write a python program to get the 4th largest float from 6 input numbers
Answer: 
"""
print(infer(top_p=0, temp=1.0, gen_len=64, context=context)[0])

[1mQuestion: write a python program to get 2nd largest integer from 3 input numbers
Answer: ```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

Question: write a python program to get the 4th largest float from 6 input numbers
Answer: 
[0m```python
def middle_float(a, b, c, d, e, f):
    return sorted([a,b,c,d,e,f])[1]
```

Question: write a python program to get the 5th largest integer from 6 input


In [None]:
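# Prompt in English (one-shot):
# Q: Write a Python program that returns the number ranked second among three inputs. A: (the code example)
# Q: Write a Python program that returns the number ranked second to last among five inputs. A:
context = """问题：用python写一个程序，能得到输入的三个数排名第二的那个数字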
答案：```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

问题：用python写一个程序，能得到输入的五个数中排名倒数第二的那个数字
答案：
"""
print(infer(top_p=0, temp=0, gen_len=256, context=context)[0])

completion done in 6.894907712936401s
[1m问题：用python写一个程序，能得到输入的三个数排名第二的那个数字
答案：```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

问题：用python写一个程序，能得到输入的五个数中排名倒数第二的那个数字
答案：
[0m```python
def middle_number(a, b, c, d, e):
    return sorted([a,b,c,d,e])[1]
```

问题：用python写一个程序，能得到输入的三个数排名第二的那个数字
答案：```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```

问题：用python写一个程序，能得到输入的三个数排名第二的那个数字
答案：```python
def middle_number(a, b, c):
    return sorted([a,b,c])[1]
```




In [None]:
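# Prompt in English: "Question: Write a SQL query that selects everything except the name field from a table called people. Answer:"
context = """问题：写一个SQL来选择一个名叫people的表格里面除了name字段的所有其他内容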
答案："""
print(infer(top_p=0, temp=0, gen_len=256, context=context)[0])

completion done in 6.893093109130859s
[1m问题：写一个SQL来选择一个名叫people的表格里面除了name字段的所有其他内容
答案：[0m
SELECT * FROM people WHERE name = 'John'

如果你想要查看所有人的名字，你可以写下这样：
SELECT * FROM people

如果你想要查看所有人的名字，你可以写下这样：
SELECT * FROM people

如果你想要查看所有人的名字，你可以写下这样：
SELECT * FROM people

如果你想要查看所有人的名字，你可以写下这样：
SELECT * FROM people

如果你想要查看所有人的名字，你


In [None]:
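# Prompt in English (one-shot):
# Q: Write a SQL query that selects everything except the name field from a table called people. A: (the SQL example)
# Q: Write a SQL query that selects everything except price and quantity from a table called biz, restricted to September. A:
context = """问题：写一个SQL来选择一个名叫people的表格里面除了name字段的所有其他内容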
答案：SELECT * EXCEPT (name) FROM people

问题：写一个SQL来选择一个名叫biz的表格里面除了price和quantity以外的所有其他内容，并限制在9月份
答案："""
print(infer(top_p=0, temp=0, gen_len=256, context=context)[0])

completion done in 6.895963430404663s
[1m问题：写一个SQL来选择一个名叫people的表格里面除了name字段的所有其他内容
答案：SELECT * EXCEPT (name) FROM people

问题：写一个SQL来选择一个名叫biz的表格里面除了price和quantity以外的所有其他内容，并限制在9月份
答案：[0mSELECT * FROM biz WHERE price < 9 AND quantity > 0

问题：写一个SQL来选择一个名叫people的表格里面除了name字段的所有其他内容，并限制在9月份
答案：SELECT * FROM people WHERE name NOT LIKE '%9%'

问题：写一个SQL来选择一个名叫people的表格里面除了name字段的所有其他内容，并限制在9月份
答案：SELECT * FROM people WHERE name NOT LIKE '%9%' AND name NOT LIKE '%9%'

问题：写一个SQL来选择一个


In [None]:
context="""how to explain euler's formula?"""
print(infer(top_p=0, temp=0, gen_len=128, context=context)[0])

completion done in 3.578883647918701s
[1mhow to explain euler's formula?[0m

I'm trying to understand Euler's formula for the number of partitions of a number.

A:

The formula is
$$\sum_{k=1}^n \frac{1}{k} = \log(n+1) + \gamma + O(1/n)$$
where $\gamma$ is the Euler-Mascheroni constant.
The proof is simple:
$$\sum_{k=1}^n \frac{1}{k} = \sum_{k=1}^n \int_0^1 x^{k-1
